Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/complementcoverage

This commit is contained in:
Till Faelligen 2022-12-23 11:03:16 +01:00
commit dbbbd2985d
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
102 changed files with 2536 additions and 455 deletions

View file

@ -331,7 +331,8 @@ jobs:
postgres: postgres
api: full-http
container:
image: matrixdotorg/sytest-dendrite:latest
# Temporary for debugging to see if this image is working better.
image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73
volumes:
- ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build

View file

@ -180,14 +180,14 @@ func startup() {
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
rsAPI := roomserver.NewInternalAPI(base)
federation := conn.CreateFederationClient(base, pSessions)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
rsAPI := roomserver.NewInternalAPI(base)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)

View file

@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI)

View file

@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)

134
clientapi/admin_test.go Normal file
View file

@ -0,0 +1,134 @@
package clientapi
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestAdminResetPassword(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
vhUser := &test.User{ID: "@vhuser:vh1"}
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
// add a vhost
base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"},
})
rsAPI := roomserver.NewInternalAPI(base)
// Needed for changing the password/login
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
aliceAdmin: "",
bob: "",
vhUser: "",
}
for u := range accessTokens {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
password := util.RandomString(8)
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: password,
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
"type": authtypes.LoginTypePassword,
"identifier": map[string]interface{}{
"type": "m.id.user",
"user": u.ID,
},
"password": password,
}))
rec := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("failed to login: %s", rec.Body.String())
}
accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String()
}
testCases := []struct {
name string
requestingUser *test.User
userID string
requestOpt test.HTTPRequestOpt
wantOK bool
withHeader bool
}{
{name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID},
{name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID},
{name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(8),
})},
{name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s
{name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": "",
})},
{name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(8),
})},
{name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID},
{name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)},
{name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(6),
})},
{name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(513),
})},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID)
if tc.requestOpt != nil {
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt)
}
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser])
}
rec := httptest.NewRecorder()
base.DendriteAdminMux.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}

View file

@ -15,6 +15,8 @@
package clientapi
import (
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/producers"
@ -26,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
@ -57,10 +58,7 @@ func AddPublicRoutes(
}
routing.Setup(
base.PublicClientAPIMux,
base.PublicWellKnownAPIMux,
base.SynapseAdminMux,
base.DendriteAdminMux,
base,
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI,

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
@ -98,20 +99,40 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
}
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
if req.Body == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Missing request body"),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"]
if !ok {
var localpart string
userID := vars["userID"]
localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting user localpart."),
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
localpart, serverName = l, s
accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: localpart,
ServerName: serverName,
}, accAvailableResp); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalAPIError(req.Context(), err),
}
}
if accAvailableResp.Available {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Unknown("User does not exist"),
}
}
request := struct {
Password string `json:"password"`
@ -128,6 +149,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
}
}
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
return *resErr
}
updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
ServerName: serverName,

View file

@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID,
IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{},
}
joinRes := roomserverAPI.PerformJoinResponse{}
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse()
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
done <- util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else {
done <- util.JSONResponse{
Code: http.StatusOK,

View file

@ -0,0 +1,158 @@
package routing
import (
"bytes"
"context"
"net/http"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestJoinRoomByIDOrAlias(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
// Create the users in the userapi
for _, u := range []*test.User{alice, bob, charlie} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
aliceDev := &uapi.Device{UserID: alice.ID}
bobDev := &uapi.Device{UserID: bob.ID}
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
// create a room with disabled guest access and invite Bob
resp := createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
RoomAliasName: "alias",
Invite: []string{bob.ID},
GuestCanJoin: false,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crResp, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// create a room with guest access enabled and invite Charlie
resp = createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
Invite: []string{charlie.ID},
GuestCanJoin: true,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// Dummy request
body := &bytes.Buffer{}
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
if err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
device *uapi.Device
roomID string
wantHTTP200 bool
}{
{
name: "User can join successfully by alias",
device: bobDev,
roomID: crResp.RoomAlias,
wantHTTP200: true,
},
{
name: "User can join successfully by roomID",
device: bobDev,
roomID: crResp.RoomID,
wantHTTP200: true,
},
{
name: "join is forbidden if user is guest",
device: charlieDev,
roomID: crResp.RoomID,
},
{
name: "room does not exist",
device: aliceDev,
roomID: "!doesnotexist:test",
},
{
name: "user from different server",
device: &uapi.Device{UserID: "@wrong:server"},
roomID: crResp.RoomAlias,
},
{
name: "user doesn't exist locally",
device: &uapi.Device{UserID: "@doesnotexist:test"},
roomID: crResp.RoomAlias,
},
{
name: "invalid room ID",
device: aliceDev,
roomID: "invalidRoomID",
},
{
name: "roomAlias does not exist",
device: aliceDev,
roomID: "#doesnotexist:test",
},
{
name: "room with guest_access event",
device: charlieDev,
roomID: crRespWithGuestAccess.RoomID,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
if tc.wantHTTP200 && !joinResp.Is2xx() {
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
}
})
}
})
}

View file

@ -7,6 +7,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -81,7 +82,7 @@ func Password(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
return *resErr
}

View file

@ -30,6 +30,7 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil"
@ -60,8 +61,6 @@ var (
)
const (
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
sessionIDLength = 24
)
@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl
return nil
}
// validatePassword returns an error response if the password is invalid
func validatePassword(password string) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
}
}
return nil
}
// validateRecaptcha returns an error response if the captcha response is invalid
func validateRecaptcha(
cfg *config.ClientAPI,
@ -636,7 +618,7 @@ func Register(
return *resErr
}
}
if resErr := validatePassword(r.Password); resErr != nil {
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
return *resErr
}
@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
return *resErr
}
if resErr := validatePassword(ssrr.Password); resErr != nil {
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
return *resErr
}
deviceID := "shared_secret_registration"

View file

@ -20,6 +20,7 @@ import (
"strings"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
@ -49,7 +50,7 @@ import (
// applied:
// nolint: gocyclo
func Setup(
publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
base *base.BaseDendrite,
cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI,
@ -63,7 +64,14 @@ func Setup(
extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, natsClient *nats.Conn,
) {
prometheus.MustRegister(amtRegUsers, sendEventDuration)
publicAPIMux := base.PublicClientAPIMux
wkMux := base.PublicWellKnownAPIMux
synapseAdminRouter := base.SynapseAdminMux
dendriteAdminRouter := base.DendriteAdminMux
if base.EnableMetrics {
prometheus.MustRegister(amtRegUsers, sendEventDuration)
}
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
@ -631,7 +639,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg)
}),

View file

@ -213,7 +213,7 @@ func main() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)

View file

@ -157,11 +157,12 @@ func main() {
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
rsComponent := roomserver.NewInternalAPI(
base,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent)
rsAPI := rsComponent
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())

View file

@ -95,7 +95,7 @@ func main() {
}
keyRing := fsAPI.KeyRing()
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
keyAPI := keyImpl
if base.UseHTTPAPIs {
keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics)

View file

@ -22,7 +22,8 @@ import (
func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
fsAPI := base.FederationAPIHTTPClient()
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
rsAPI := base.RoomserverHTTPClient()
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
intAPI.SetUserAPI(base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics)

View file

@ -91,7 +91,7 @@ Please use PostgreSQL wherever possible, especially if you are planning to run a
## Dendrite is using a lot of CPU
Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the
CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would
CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would
be a huge help too, as that will help us to understand where the CPU time is going.
## Dendrite is using a lot of RAM
@ -99,7 +99,7 @@ be a huge help too, as that will help us to understand where the CPU time is goi
As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work
at a given instant. However, if it is using more RAM than you expect for a long time, that's probably
not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage
ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that
ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that
would be a huge help too, as that will help us to understand where the memory usage is happening.
## Dendrite is running out of PostgreSQL database connections

View file

@ -231,9 +231,9 @@ GEM
jekyll-seo-tag (~> 2.1)
minitest (5.15.0)
multipart-post (2.1.1)
nokogiri (1.13.9-arm64-darwin)
nokogiri (1.13.10-arm64-darwin)
racc (~> 1.4)
nokogiri (1.13.9-x86_64-linux)
nokogiri (1.13.10-x86_64-linux)
racc (~> 1.4)
octokit (4.22.0)
faraday (>= 0.9)
@ -241,7 +241,7 @@ GEM
pathutil (0.16.2)
forwardable-extended (~> 2.6)
public_suffix (4.0.7)
racc (1.6.0)
racc (1.6.1)
rb-fsevent (0.11.1)
rb-inotify (0.10.1)
ffi (~> 1.0)

View file

@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL
all rooms which they are currently joined. A JSON body will be returned containing
the room IDs of all affected rooms.
## POST `/_dendrite/admin/resetPassword/{localpart}`
## POST `/_dendrite/admin/resetPassword/{userID}`
Reset the password of a local user.
Request body format:
@ -54,9 +56,6 @@ Request body format:
}
```
Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_dendrite/admin/fulltext/reindex`
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.

View file

@ -9,6 +9,28 @@ permalink: /development/contributing
Everyone is welcome to contribute to Dendrite! We aim to make it as easy as
possible to get started.
## Contribution types
We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it
is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept:
- bug fixes
- security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests)
We will accept the following with caveats:
- documentation fixes, provided they do not add additional instructions which can end up going out-of-date,
e.g example configs, shell commands.
- performance fixes, provided they do not add significantly more maintenance burden.
- additional functionality on existing features, provided the functionality is small and maintainable.
- additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations
- test-only changes, provided they help improve coverage or test tricky code.
The following items are at risk of not being accepted:
- Configuration or CLI changes, particularly ones which increase the overall configuration surface.
The following items are unlikely to be accepted into a main Dendrite release for now:
- New MSC implementations.
- New features which are not in the specification.
## Sign off
We require that everyone who contributes to the project signs off their contributions
@ -35,7 +57,7 @@ to do so for future contributions.
## Getting up and running
See the [Installation](installation) section for information on how to build an
See the [Installation](../installation) section for information on how to build an
instance of Dendrite. You will likely need this in order to test your changes.
## Code style
@ -129,7 +151,7 @@ significant amount of CPU and RAM.
Once the code builds, run [Sytest](https://github.com/matrix-org/sytest)
according to the guide in
[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image)
[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image)
so you can see whether something is being broken and whether there are newly
passing tests.

View file

@ -232,7 +232,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
}
func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) {
joined := make([]gomatrixserverlib.ServerName, len(addedJoined))
joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined))
for _, added := range addedJoined {
joined = append(joined, added.ServerName)
}

View file

@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl
return nil
}
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if pdus, ok := d.associatedPDUs[serverName]; ok {
count = int64(len(pdus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if edus, ok := d.associatedEDUs[serverName]; ok {
count = int64(len(edus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()

View file

@ -45,9 +45,6 @@ type Database interface {
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)

View file

@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er
return
}
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
return s, sqlutil.StatementList{
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
}.Prepare(db)
}
func (s *inboundPeeksStatements) InsertInboundPeek(

View file

@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err
return
}
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
return s, sqlutil.StatementList{
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
}.Prepare(db)
}
func (s *outboundPeeksStatements) InsertOutboundPeek(

View file

@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
@ -81,7 +77,6 @@ type queueEDUsStatements struct {
deleteQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {

View file

@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
@ -71,7 +67,6 @@ type queuePDUsStatements struct {
deleteQueuePDUsStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueuePDUServerNamesStmt *sql.Stmt
}
@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
return
}
@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,

View file

@ -162,15 +162,6 @@ func (d *Database) CleanEDUs(
})
}
// GetPendingEDUCount returns the number of EDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingEDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have EDUs
// waiting to be sent.
func (d *Database) GetPendingEDUServerNames(

View file

@ -141,15 +141,6 @@ func (d *Database) CleanPDUs(
})
}
// GetPendingPDUCount returns the number of PDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingPDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have PDUs
// waiting to be sent.
func (d *Database) GetPendingPDUServerNames(

View file

@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro
return
}
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
return s, sqlutil.StatementList{
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
}.Prepare(db)
}
func (s *inboundPeeksStatements) InsertInboundPeek(

View file

@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er
return
}
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
return s, sqlutil.StatementList{
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
}.Prepare(db)
}
func (s *outboundPeeksStatements) InsertOutboundPeek(

View file

@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
@ -82,7 +78,6 @@ type queueEDUsStatements struct {
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {

View file

@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUsServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
@ -79,7 +75,6 @@ type queuePDUsStatements struct {
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
}
@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
return
}
@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,

View file

@ -2,10 +2,12 @@ package storage_test
import (
"context"
"reflect"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/storage"
@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) {
assert.Equal(t, 2, len(data))
})
}
func TestOutboundPeeking(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
// Add outbound peek
if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if outboundPeek1.PeekID != peekID {
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
}
if outboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
}
if outboundPeek1.ServerName != serverName {
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
}
if outboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if outboundPeek1.ServerName != outboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
}
if outboundPeek1.RoomID != outboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
}
// insert some peeks
peekIDs := []string{peekID}
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
}
gotPeekIDs := make([]string, 0, len(outboundPeeks))
for _, p := range outboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}
func TestInboundPeeking(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
// Add inbound peek
if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if inboundPeek1.PeekID != peekID {
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
}
if inboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
}
if inboundPeek1.ServerName != serverName {
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
}
if inboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if inboundPeek1.ServerName != inboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
}
if inboundPeek1.RoomID != inboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
}
// insert some peeks
peekIDs := []string{peekID}
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
}
gotPeekIDs := make([]string, 0, len(inboundPeeks))
for _, p := range inboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}

View file

@ -0,0 +1,149 @@
package tables_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
var tab tables.FederationInboundPeeks
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresInboundPeeksTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteInboundPeeksTable(db)
}
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
return tab, close
}
func TestInboundPeeksTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateInboundpeeksTable(t, dbType)
defer closeDB()
// Insert a peek
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if inboundPeek1.PeekID != peekID {
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
}
if inboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
}
if inboundPeek1.ServerName != serverName {
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
}
if inboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if inboundPeek1.ServerName != inboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
}
if inboundPeek1.RoomID != inboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
}
// delete the peek
if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
t.Fatal(err)
}
// There should be no peek anymore
peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if peek != nil {
t.Fatalf("got a peek which should be deleted: %+v", peek)
}
// insert some peeks
var peekIDs []string
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
}
gotPeekIDs := make([]string, 0, len(inboundPeeks))
for _, p := range inboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
// And delete them again
if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil {
t.Fatal(err)
}
// they should be gone now
inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) > 0 {
t.Fatal("got inbound peeks which should be deleted")
}
})
}

View file

@ -28,7 +28,6 @@ type FederationQueuePDUs interface {
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
}
@ -38,7 +37,6 @@ type FederationQueueEDUs interface {
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error

View file

@ -0,0 +1,148 @@
package tables_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
var tab tables.FederationOutboundPeeks
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresOutboundPeeksTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db)
}
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
return tab, close
}
func TestOutboundPeeksTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateOutboundpeeksTable(t, dbType)
defer closeDB()
// Insert a peek
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if outboundPeek1.PeekID != peekID {
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
}
if outboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
}
if outboundPeek1.ServerName != serverName {
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
}
if outboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if outboundPeek1.ServerName != outboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
}
if outboundPeek1.RoomID != outboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
}
// delete the peek
if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
t.Fatal(err)
}
// There should be no peek anymore
peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if peek != nil {
t.Fatalf("got a peek which should be deleted: %+v", peek)
}
// insert some peeks
var peekIDs []string
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
}
gotPeekIDs := make([]string, 0, len(outboundPeeks))
for _, p := range outboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
// And delete them again
if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil {
t.Fatal(err)
}
// they should be gone now
outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) > 0 {
t.Fatal("got outbound peeks which should be deleted")
}
})
}

View file

@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeHTMLAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName)
defer span.Finish()
@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
}
}
if !enableMetrics {
return http.HandlerFunc(withSpan)
}
return promhttp.InstrumentHandlerCounter(
promauto.NewCounterVec(
prometheus.CounterOpts{

View file

@ -33,6 +33,11 @@ import (
"github.com/matrix-org/dendrite/setup/config"
)
// logrus is using a global variable when we're using `logrus.AddHook`
// this unfortunately results in us adding the same hook multiple times.
// This map ensures we only ever add one level hook.
var stdLevelLogAdded = make(map[logrus.Level]bool)
type utcFormatter struct {
logrus.Formatter
}

View file

@ -22,16 +22,16 @@ import (
"log/syslog"
"github.com/MFAshby/stdemuxerhook"
"github.com/matrix-org/dendrite/setup/config"
"github.com/sirupsen/logrus"
lSyslog "github.com/sirupsen/logrus/hooks/syslog"
"github.com/matrix-org/dendrite/setup/config"
)
// SetupHookLogging configures the logging hooks defined in the configuration.
// If something fails here it means that the logging was improperly configured,
// so we just exit with the error
func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
stdLogAdded := false
for _, hook := range hooks {
// Check we received a proper logging level
level, err := logrus.ParseLevel(hook.Level)
@ -54,14 +54,11 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
setupSyslogHook(hook, level, componentName)
case "std":
setupStdLogHook(level)
stdLogAdded = true
default:
logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type)
}
}
if !stdLogAdded {
setupStdLogHook(logrus.InfoLevel)
}
setupStdLogHook(logrus.InfoLevel)
// Hooks are now configured for stdout/err, so throw away the default logger output
logrus.SetOutput(io.Discard)
}
@ -88,7 +85,11 @@ func checkSyslogHookParams(params map[string]interface{}) {
}
func setupStdLogHook(level logrus.Level) {
if stdLevelLogAdded[level] {
return
}
logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())})
stdLevelLogAdded[level] = true
}
func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) {

44
internal/validate.go Normal file
View file

@ -0,0 +1,44 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"fmt"
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/util"
)
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
// ValidatePassword returns an error response if the password is invalid
func ValidatePassword(password string) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
}
}
return nil
}

View file

@ -24,6 +24,8 @@ import (
"sync"
"time"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -102,6 +104,7 @@ type DeviceListUpdater struct {
// block on or timeout via a select.
userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex
rsAPI rsapi.KeyserverRoomserverAPI
}
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
}
type DeviceListUpdaterAPI interface {
@ -140,7 +145,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
thisServer gomatrixserverlib.ServerName,
rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater {
return &DeviceListUpdater{
process: process,
@ -154,6 +159,7 @@ func NewDeviceListUpdater(
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
rsAPI: rsAPI,
}
}
@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error {
go u.worker(ch)
}
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error {
return nil
}
// CleanUp removes stale device entries for users we don't share a room with anymore
func (u *DeviceListUpdater) CleanUp() error {
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
res := rsapi.QueryLeftUsersResponse{}
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
return err
}
if len(res.LeftUsers) == 0 {
return nil
}
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
}
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
u.mu.Lock()
defer u.mu.Unlock()

View file

@ -30,7 +30,12 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
var (
@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct {
mu sync.Mutex // protect staleUsers
}
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
return nil
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) {
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) {
UserID: remoteUserID,
}
err := updater.Update(ctx, event)
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq)
return <-fedCh, nil
})
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) {
t.Errorf("user %s is marked as stale", userID)
}
}
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
t.Helper()
base, _, _ := testrig.Base(nil)
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
if err != nil {
t.Fatal(err)
}
return db, clearDB
}
type mockKeyserverRoomserverAPI struct {
leftUsers []string
}
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
res.LeftUsers = m.leftUsers
return nil
}
func TestDeviceListUpdater_CleanUp(t *testing.T) {
processCtx := process.NewProcessContext()
alice := test.NewUser(t)
bob := test.NewUser(t)
// Bob is not joined to any of our rooms
rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clearDB := mustCreateKeyserverDB(t, dbType)
defer clearDB()
// This should not get deleted
if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
t.Error(err)
}
// this one should get deleted
if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
t.Error(err)
}
updater := NewDeviceListUpdater(processCtx, db, nil,
nil, nil,
0, rsAPI, "test")
if err := updater.CleanUp(); err != nil {
t.Error(err)
}
// check that we still have Alice in our stale list
staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Error(err)
}
// There should only be Alice
wantCount := 1
if count := len(staleUsers); count != wantCount {
t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
}
if staleUsers[0] != alice.ID {
t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
}
})
}

View file

@ -18,6 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/consumers"
@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
rsAPI rsapi.KeyserverRoomserverAPI,
) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
@ -47,6 +50,7 @@ func NewInternalAPI(
if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database")
}
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
JetStream: js,
@ -58,8 +62,14 @@ func NewInternalAPI(
FedClient: fedClient,
Producer: keyChangeProducer,
}
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater
// Remove users which we don't share a room with anymore
if err := updater.CleanUp(); err != nil {
logrus.WithError(err).Error("failed to cleanup stale device lists")
}
go func() {
if err := updater.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start device list updater")

View file

@ -0,0 +1,29 @@
package keyserver
import (
"context"
"testing"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
type mockKeyserverRoomserverAPI struct {
leftUsers []string
}
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
res.LeftUsers = m.leftUsers
return nil
}
// Merely tests that we can create an internal keyserver API
func Test_NewInternalAPI(t *testing.T) {
rsAPI := &mockKeyserverRoomserverAPI{}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
defer closeBase()
_ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
})
}

View file

@ -85,4 +85,9 @@ type Database interface {
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
}

View file

@ -19,6 +19,10 @@ import (
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
deleteStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {

View file

@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
return nil
})
}
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *Database) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View file

@ -17,8 +17,11 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
type staleDeviceListsStatements struct {
db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
stmt = sqlutil.TxStmt(txn, stmt)
params := make([]any, len(userIDs))
for i := range userIDs {
params[i] = userIDs[i]
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {

View file

@ -56,6 +56,7 @@ type KeyChanges interface {
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {

View file

@ -0,0 +1,94 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/test"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, nil)
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
}
if err != nil {
t.Fatalf("failed to create new table: %s", err)
}
return tab, close
}
func TestStaleDeviceLists(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := "@charlie:localhost"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateTable(t, dbType)
defer closeDB()
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
// Query one server
wantStaleUsers := []string{alice.ID, bob.ID}
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Query all servers
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Delete stale devices
deleteUsers := []string{alice.ID, bob.ID}
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
t.Fatalf("failed to delete stale device lists: %s", err)
}
// Verify we don't get anything back after deleting
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if gotCount := len(gotStaleUsers); gotCount > 0 {
t.Fatalf("expected no stale users, got %d", gotCount)
}
})
}

View file

@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@ -199,3 +200,7 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
}
type KeyserverRoomserverAPI interface {
QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
}

View file

@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct {
Impl RoomserverInternalAPI
}
func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
err := t.Impl.QueryLeftUsers(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
t.Impl.SetFederationAPI(fsAPI, keyRing)
}

View file

@ -78,6 +78,7 @@ const (
type PerformJoinRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"`
IsGuest bool `json:"is_guest"`
Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"`

View file

@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct {
// do not have known state will return an empty array here.
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
// a room with anymore. This is used to cleanup stale device list entries, where we would
// otherwise keep on trying to get device lists.
type QueryLeftUsersRequest struct {
StaleDeviceListUsers []string `json:"user_ids"`
}
// QueryLeftUsersResponse is the response to QueryLeftUsersRequest.
type QueryLeftUsersResponse struct {
LeftUsers []string `json:"user_ids"`
}

View file

@ -4,6 +4,10 @@ import (
"context"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
@ -19,9 +23,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI
r.KeyRing = keyRing
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
if err != nil {
logrus.Panic(err)
}
r.Inputer = &input.Inputer{
Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base,
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
JetStream: r.JetStream,
NATSClient: r.NATSClient,
Durable: nats.Durable(r.Durable),
ServerName: r.Cfg.Matrix.ServerName,
ServerName: r.ServerName,
SigningIdentity: identity,
FSAPI: fsAPI,
KeyRing: keyRing,
ACLs: r.ServerACLs,
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Queryer: r.Queryer,
}
r.Peeker = &perform.Peeker{
ServerName: r.Cfg.Matrix.ServerName,
ServerName: r.ServerName,
Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI,
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer,
}
r.Unpeeker = &perform.Unpeeker{
ServerName: r.Cfg.Matrix.ServerName,
ServerName: r.ServerName,
Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI,
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI
}
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {

View file

@ -23,6 +23,8 @@ import (
"sync"
"time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/Arceliar/phony"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
@ -79,6 +81,7 @@ type Inputer struct {
JetStream nats.JetStreamContext
Durable nats.SubOpt
ServerName gomatrixserverlib.ServerName
SigningIdentity *gomatrixserverlib.SigningIdentity
FSAPI fedapi.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs
@ -87,6 +90,7 @@ type Inputer struct {
workers sync.Map // room ID -> *worker
Queryer *query.Queryer
UserAPI userapi.RoomserverUserAPI
}
// If a room consumer is inactive for a while then we will allow NATS

View file

@ -19,6 +19,7 @@ package input
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
@ -31,6 +32,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
userAPI "github.com/matrix-org/dendrite/userapi/api"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
}
}
// If guest_access changed and is not can_join, kick all guest users.
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
}
}
// Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered)
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
succeeded = true
return nil
}
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
return err
}
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
if err != nil {
return err
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: event.RoomID(),
}
latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return err
}
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil {
continue
}
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: senderDomain,
}, accountRes); err != nil {
return err
}
if accountRes.Account == nil {
continue
}
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
continue
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
return err
}
memberContent.Membership = gomatrixserverlib.Leave
stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: event.RoomID(),
Type: gomatrixserverlib.MRoomMember,
StateKey: &stateKey,
Sender: stateKey,
PrevEvents: prevEvents,
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
return err
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil {
return err
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
if err != nil {
return err
}
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
Origin: senderDomain,
SendAsServer: string(senderDomain),
})
prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(),
}
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: inputEvents,
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
}
inputRes := &api.InputRoomEventsResponse{}
return r.InputRoomEvents(ctx, inputReq, inputRes)
}

View file

@ -16,6 +16,7 @@ package perform
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
}
}
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
if req.IsGuest {
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
guestAccess := "forbidden"
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
}
if guestAccessEvent != nil {
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
}
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
}
}
// If we should do a forced federated join then do that.
var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin {

View file

@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
return nil
}
func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
var err error
res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers)
return err
}
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil {

View file

@ -63,6 +63,7 @@ const (
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
)
type httpRoomserverInternalAPI struct {
@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
h.httpClient, ctx, request, response,
)
}
func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
return httputil.CallInternalRPCAPI(
"RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
h.httpClient, ctx, request, response,
)
}

View file

@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe
RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent),
)
internalAPIMux.Handle(
RoomserverQueryLeftMembersPath,
httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers),
)
}

View file

@ -2,27 +2,163 @@ package roomserver_test
import (
"context"
"net/http"
"reflect"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/userapi"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
t.Helper()
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
return base, db, close
}
func Test_SharedUsers(t *testing.T) {
func TestUsers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
t.Run("shared users", func(t *testing.T) {
testSharedUsers(t, rsAPI)
})
t.Run("kick users", func(t *testing.T) {
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
rsAPI.SetUserAPI(usrAPI)
testKickUsers(t, rsAPI, usrAPI)
})
})
}
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
// Invite and join Bob
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "invite",
}, test.WithStateKey(bob.ID))
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
}
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
// Join with the guest user
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
// Create the users in the userapi, so the RSAPI can query the account type later
for _, u := range []*test.User{alice, bob} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &userAPI.PerformAccountCreationResponse{}
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
// Create the room in the database
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// Get the membership events BEFORE revoking guest access
membershipRes := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
t.Errorf("failed to query membership for room: %s", err)
}
// revoke guest access
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
// to loop and wait for the events to be processed by the roomserver.
for i := 0; i <= 20; i++ {
// Get the membership events AFTER revoking guest access
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
t.Errorf("failed to query membership for room: %s", err)
}
// The membership events should NOT match, as Bob (guest user) should now be kicked from the room
if !reflect.DeepEqual(membershipRes, membershipRes2) {
return
}
time.Sleep(time.Millisecond * 10)
}
t.Errorf("memberships didn't change in time")
}
func Test_QueryLeftUsers(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
@ -48,22 +184,42 @@ func Test_SharedUsers(t *testing.T) {
t.Fatalf("failed to send events: %v", err)
}
// Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
// Query the left users, there should only be "@idontexist:test",
// as Alice and Bob are still joined.
res := &api.QueryLeftUsersResponse{}
leftUserID := "@idontexist:test"
getLeftUsersList := []string{alice.ID, bob.ID, leftUserID}
testCase := func(rsAPI api.RoomserverInternalAPI) {
if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil {
t.Fatalf("unable to query left users: %v", err)
}
wantCount := 1
if count := len(res.LeftUsers); count > wantCount {
t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count)
}
if res.LeftUsers[0] != leftUserID {
t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0])
}
}
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
roomserver.AddInternalRoutes(router, rsAPI, false)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil)
if err != nil {
t.Fatalf("failed to create HTTP client")
}
testCase(httpAPI)
})
t.Run("Monolith", func(t *testing.T) {
testCase(rsAPI)
// also test tracing
traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI}
testCase(traceAPI)
})
})
}

View file

@ -172,5 +172,6 @@ type Database interface {
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
}

View file

@ -21,12 +21,13 @@ import (
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
const membershipSchema = `
@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid = ANY($2)
`
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@ -174,6 +181,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
}
func CreateMembershipTable(db *sql.DB) error {
@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
}.Prepare(db)
}
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}
func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,

View file

@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
return result, nil
}
// GetLeftUsers calculates users we (the server) don't share a room with anymore.
func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
// Get the userNID for all users with a stale device list
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
// Create a map from userNID -> userID
for userID, nid := range stateKeyNIDMap {
userNIDs = append(userNIDs, nid)
userNIDtoUserID[nid] = userID
}
// Get all users whose membership is still join, knock or invite.
stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
if err != nil {
return nil, err
}
// Remove joined users from the "user with stale devices" list, which contains left AND joined users
for _, joinedUser := range stillJoinedUsersNIDs {
delete(userNIDtoUserID, joinedUser)
}
// The users still in our userNIDtoUserID map are the users we don't share a room with anymore,
// and the return value we are looking for.
leftUsers := make([]string, 0, len(userNIDtoUserID))
for _, userID := range userNIDtoUserID {
leftUsers = append(leftUsers, userID)
}
return leftUsers, nil
}
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)

View file

@ -0,0 +1,96 @@
package shared_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) {
t.Helper()
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
base, _, _ := testrig.Base(nil)
dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var membershipTable tables.Membership
var stateKeyTable tables.EventStateKeys
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateEventStateKeysTable(db)
assert.NoError(t, err)
err = postgres.CreateMembershipTable(db)
assert.NoError(t, err)
membershipTable, err = postgres.PrepareMembershipTable(db)
assert.NoError(t, err)
stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateEventStateKeysTable(db)
assert.NoError(t, err)
err = sqlite3.CreateMembershipTable(db)
assert.NoError(t, err)
membershipTable, err = sqlite3.PrepareMembershipTable(db)
assert.NoError(t, err)
stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
}
assert.NoError(t, err)
return &shared.Database{
DB: db,
EventStateKeysTable: stateKeyTable,
MembershipTable: membershipTable,
Writer: sqlutil.NewExclusiveWriter(),
}, func() {
err := base.Close()
assert.NoError(t, err)
clearDB()
err = db.Close()
assert.NoError(t, err)
}
}
func Test_GetLeftUsers(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRoomserverDatabase(t, dbType)
defer close()
// Create dummy entries
for _, user := range []*test.User{alice, bob, charlie} {
nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID)
assert.NoError(t, err)
err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true)
assert.NoError(t, err)
// We must update the membership with a non-zero event NID or it will get filtered out in later queries
membershipNID := tables.MembershipStateLeaveOrBan
if user == alice {
membershipNID = tables.MembershipStateJoin
}
_, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false)
assert.NoError(t, err)
}
// Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership
expectedUserIDs := []string{bob.ID, charlie.ID}
leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID})
assert.NoError(t, err)
assert.ElementsMatch(t, expectedUserIDs, leftUsers)
})
}

View file

@ -21,12 +21,13 @@ import (
"fmt"
"strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
const membershipSchema = `
@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" +
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectJoinedUsersSQL = `
SELECT DISTINCT target_nid
FROM roomserver_membership m
WHERE membership_nid > $1 AND target_nid IN ($2)
`
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@ -149,6 +156,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
// selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
}
func CreateMembershipTable(db *sql.DB) error {
@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership(
)
return err
}
func (s *membershipStatements) SelectJoinedUsers(
ctx context.Context, txn *sql.Tx,
targetUserNIDs []types.EventStateKeyNID,
) ([]types.EventStateKeyNID, error) {
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
params := make([]any, len(targetUserNIDs)+1)
params[0] = tables.MembershipStateLeaveOrBan
for i := range targetUserNIDs {
params[i+1] = targetUserNIDs[i]
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
var targetNID types.EventStateKeyNID
for rows.Next() {
if err = rows.Scan(&targetNID); err != nil {
return nil, err
}
result = append(result, targetNID)
}
return result, rows.Err()
}

View file

@ -144,6 +144,7 @@ type Membership interface {
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
}
type Published interface {

View file

@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
assert.NoError(t, err)
assert.Equal(t, 1, len(knownUsers))
// get users we share a room with, given their userNID
joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
assert.NoError(t, err)
// Only userNIDs[0] is actually joined, so we only expect this userNID
assert.Equal(t, userNIDs[:1], joinedUsers)
})
}

View file

@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
return id, nil
}
}
return nil, fmt.Errorf("no signing identity %q", serverName)
return nil, fmt.Errorf("no signing identity for %q", serverName)
}
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {

View file

@ -16,8 +16,10 @@ package config
import (
"fmt"
"reflect"
"testing"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2"
)
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
}
}
}
func Test_SigningIdentityFor(t *testing.T) {
tests := []struct {
name string
virtualHosts []*VirtualHost
serverName gomatrixserverlib.ServerName
want *gomatrixserverlib.SigningIdentity
wantErr bool
}{
{
name: "no virtual hosts defined",
wantErr: true,
},
{
name: "no identity found",
serverName: gomatrixserverlib.ServerName("doesnotexist"),
wantErr: true,
},
{
name: "found identity",
serverName: gomatrixserverlib.ServerName("main"),
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
},
{
name: "identity found on virtual hosts",
serverName: gomatrixserverlib.ServerName("vh2"),
virtualHosts: []*VirtualHost{
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
},
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Global{
VirtualHosts: tt.virtualHosts,
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "main",
},
}
got, err := c.SigningIdentityFor(tt.serverName)
if (err != nil) != tt.wantErr {
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error {
// Normal NATS subscription, used by Request/Reply
_, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) {
userID := msg.Header.Get(jetstream.UserID)
presence, err := s.db.GetPresence(context.Background(), userID)
presences, err := s.db.GetPresences(context.Background(), []string{userID})
m := &nats.Msg{
Header: nats.Header{},
}
@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error {
}
return
}
if presence == nil {
presence = &types.PresenceInternal{
UserID: userID,
}
presence := &types.PresenceInternal{
UserID: userID,
}
if len(presences) > 0 {
presence = presences[0]
}
deviceRes := api.QueryDevicesResponse{}

View file

@ -106,7 +106,7 @@ type DatabaseTransaction interface {
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
}
@ -186,7 +186,7 @@ type Database interface {
}
type Presence interface {
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error)
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
}

View file

@ -19,10 +19,12 @@ import (
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const presenceSchema = `
@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id"
const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" +
"SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1"
" WHERE user_id = ANY($1)"
const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence(
return
}
// GetPresenceForUser returns the current presence of a user.
func (p *presenceStatements) GetPresenceForUser(
// GetPresenceForUsers returns the current presence for a list of users.
// If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx,
userID string,
) (*types.PresenceInternal, error) {
result := &types.PresenceInternal{
UserID: userID,
}
userIDs []string,
) ([]*types.PresenceInternal, error) {
result := make([]*types.PresenceInternal, 0, len(userIDs))
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows {
return nil, nil
rows, err := stmt.QueryContext(ctx, pq.Array(userIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
}
result.ClientFields.Presence = result.Presence.String()
return result, err
}

View file

@ -57,31 +57,23 @@ type Database struct {
}
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
return d.NewDatabaseTransaction(ctx)
/*
TODO: Repeatable read is probably the right thing to do here,
but it seems to cause some problems with the invite tests, so
need to investigate that further.
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
// Set the isolation level so that we see a snapshot of the database.
// In PostgreSQL repeatable read transactions will see a snapshot taken
// at the first query, and since the transaction is read-only it can't
// run into any serialisation errors.
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
})
if err != nil {
return nil, err
}
return &DatabaseTransaction{
Database: d,
ctx: ctx,
txn: txn,
}, nil
*/
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
// Set the isolation level so that we see a snapshot of the database.
// In PostgreSQL repeatable read transactions will see a snapshot taken
// at the first query, and since the transaction is read-only it can't
// run into any serialisation errors.
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
})
if err != nil {
return nil, err
}
return &DatabaseTransaction{
Database: d,
ctx: ctx,
txn: txn,
}, nil
}
func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
@ -572,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t
return pos, err
}
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, nil, userID)
func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUsers(ctx, nil, userIDs)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {

View file

@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
}
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
}
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {

View file

@ -17,12 +17,14 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const presenceSchema = `
@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id"
const selectPresenceForUserSQL = "" +
"SELECT presence, status_msg, last_active_ts" +
"SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" +
" WHERE user_id = $1 LIMIT 1"
" WHERE user_id IN ($1)"
const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence(
return
}
// GetPresenceForUser returns the current presence of a user.
func (p *presenceStatements) GetPresenceForUser(
// GetPresenceForUsers returns the current presence for a list of users.
// If the user doesn't have a presence status yet, it is omitted from the response.
func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx,
userID string,
) (*types.PresenceInternal, error) {
result := &types.PresenceInternal{
UserID: userID,
userIDs []string,
) ([]*types.PresenceInternal, error) {
qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
prepStmt, err := p.db.Prepare(qry)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
if err == sql.ErrNoRows {
return nil, nil
defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed")
params := make([]interface{}, len(userIDs))
for i := range userIDs {
params[i] = userIDs[i]
}
rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
result := make([]*types.PresenceInternal, 0, len(userIDs))
for rows.Next() {
presence := &types.PresenceInternal{}
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
return nil, err
}
presence.ClientFields.Presence = presence.Presence.String()
result = append(result, presence)
}
result.ClientFields.Presence = result.Presence.String()
return result, err
}

View file

@ -207,7 +207,7 @@ type Ignores interface {
type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
}

View file

@ -0,0 +1,136 @@
package tables_test
import (
"context"
"database/sql"
"reflect"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Presence
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresPresenceTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqlitePresenceTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, close
}
func TestPresence(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
ctx := context.Background()
statusMsg := "Hello World!"
timestamp := gomatrixserverlib.AsTimestamp(time.Now())
var txn *sql.Tx
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustPresenceTable(t, dbType)
defer closeDB()
// Insert some presences
pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos := types.StreamPosition(1)
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false)
if err != nil {
t.Error(err)
}
wantPos = 2
if pos != wantPos {
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
}
// verify the expected max presence ID
maxPos, err := tab.GetMaxPresenceID(ctx, txn)
if err != nil {
t.Error(err)
}
if maxPos != wantPos {
t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos)
}
// This should increment the position
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true)
if err != nil {
t.Error(err)
}
wantPos = pos
if wantPos <= maxPos {
t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos)
}
// This should return only Bobs status
presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10})
if err != nil {
t.Error(err)
}
if c := len(presences); c > 1 {
t.Errorf("expected only one presence, got %d", c)
}
// Validate the response
wantPresence := &types.PresenceInternal{
UserID: bob.ID,
Presence: types.PresenceOnline,
StreamPos: wantPos,
LastActiveTS: timestamp,
ClientFields: types.PresenceClientResponse{
LastActiveAgo: 0,
Presence: types.PresenceOnline.String(),
StatusMsg: &statusMsg,
},
}
if !reflect.DeepEqual(wantPresence, presences[bob.ID]) {
t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence)
}
// Try getting presences for existing and non-existing users
getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"}
presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers)
if err != nil {
t.Error(err)
}
if len(presencesForUsers) >= len(getUsers) {
t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers))
}
})
}

View file

@ -17,6 +17,7 @@ package streams
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/matrix-org/gomatrixserverlib"
@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync(
return from
}
if len(presences) == 0 {
getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences)
if err != nil {
req.Log.WithError(err).Error("getNeededUsersFromRequest failed")
return from
}
// Got no presence between range and no presence to get from the database
if len(getPresenceForUsers) == 0 && len(presences) == 0 {
return to
}
// add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) > 0 {
// TODO: Check if this is working better than before.
if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
req.Log.WithError(err).Error("unable to refresh notifier lists")
return from
}
NewlyJoinedLoop:
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
if err != nil {
req.Log.WithError(err).Error("unable to query presence for user")
_ = snapshot.Rollback()
return from
}
if len(presences) > req.Filter.Presence.Limit {
break NewlyJoinedLoop
}
}
}
dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers)
if err != nil {
req.Log.WithError(err).Error("unable to query presence for user")
_ = snapshot.Rollback()
return from
}
for _, presence := range dbPresences {
presences[presence.UserID] = presence
}
lastPos := from
@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync(
return lastPos
}
func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) {
getPresenceForUsers := []string{}
// Add presence for users which newly joined a room
for userID := range req.MembershipChanges {
if _, ok := presences[userID]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, userID)
}
// add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) == 0 {
return getPresenceForUsers, nil
}
// TODO: Check if this is working better than before.
if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err)
}
for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers {
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue
}
getPresenceForUsers = append(getPresenceForUsers, roomUsers[i])
}
}
return getPresenceForUsers, nil
}
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {

View file

@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
}
// ensure we also send the current status_msg to federated servers and not nil
dbPresence, err := db.GetPresence(context.Background(), userID)
dbPresence, err := db.GetPresences(context.Background(), []string{userID})
if err != nil && err != sql.ErrNoRows {
return
}
if dbPresence != nil {
newPresence.ClientFields = dbPresence.ClientFields
if len(dbPresence) > 0 && dbPresence[0] != nil {
newPresence.ClientFields = dbPresence[0].ClientFields
}
newPresence.ClientFields.Presence = presenceID.String()

View file

@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ
return 0, nil
}
func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
return &types.PresenceInternal{}, nil
func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) {
return []*types.PresenceInternal{}, nil
}
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {

View file

@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one
Leaves are present in non-gapped incremental syncs
# Below test was passing for the wrong reason, failing correctly since #2858
New federated private chats get full presence information (SYN-115)
New federated private chats get full presence information (SYN-115)
# We don't have any state to calculate m.room.guest_access when accepting invites
Guest users can accept invites to private rooms over federation

View file

@ -763,4 +763,7 @@ AS and main public room lists are separate
local user has tags copied to the new room
remote user has tags copied to the new room
/upgrade moves remote aliases to the new room
Local and remote users' homeservers remove a room from their public directory on upgrade
Local and remote users' homeservers remove a room from their public directory on upgrade
Guest users denied access over federation if guest access prohibited
Guest users are kicked from guest_access rooms on revocation of guest_access
Guest users are kicked from guest_access rooms on revocation of guest_access over federation

View file

@ -38,11 +38,12 @@ var (
)
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
visibility gomatrixserverlib.HistoryVisibility
creator *User
ID string
Version gomatrixserverlib.RoomVersion
preset Preset
guestCanJoin bool
visibility gomatrixserverlib.HistoryVisibility
creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
if r.guestCanJoin {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
"guest_access": "can_join",
}, WithStateKey(""))
}
}
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
r.Version = ver
}
}
func GuestsCanJoin(canJoin bool) roomModifier {
return func(t *testing.T, r *Room) {
r.guestCanJoin = canJoin
}
}

View file

@ -108,7 +108,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat
cfg.Global.JetStream.InMemory = true
cfg.SyncAPI.Fulltext.InMemory = true
cfg.FederationAPI.KeyPerspectives = nil
base := base.NewBaseDendrite(cfg, "Tests")
base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics)
js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream)
return base, js, jc
}

View file

@ -47,7 +47,7 @@ var (
type User struct {
ID string
accountType api.AccountType
AccountType api.AccountType
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
func WithAccountType(accountType api.AccountType) UserOpt {
return func(u *User) {
u.accountType = accountType
u.AccountType = accountType
}
}

View file

@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the media api
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
ServerName gomatrixserverlib.ServerName
Medium string
}
type QueryAccountByLocalpartRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryAccountByLocalpartResponse struct {
Account *Account
}

View file

@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err
}
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {

View file

@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
return
}
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {

View file

@ -60,6 +60,7 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
h.httpClient, ctx, request, response,
)
}
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
ctx context.Context,
req *api.QueryAccountByLocalpartRequest,
res *api.QueryAccountByLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
h.httpClient, ctx, req, res,
)
}

View file

@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
)
internalAPIMux.Handle(
QueryAccountByLocalpartPath,
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
)
}

View file

@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
})
})
}
func TestQueryAccountByLocalpart(t *testing.T) {
alice := test.NewUser(t)
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
if err != nil {
t.Error(err)
}
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
// Query existing account
queryAccResp := &api.QueryAccountByLocalpartResponse{}
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: userServername,
}, queryAccResp); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
}
// Query non-existent account, this should result in an error
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: "doesnotexist",
ServerName: userServername,
}, queryAccResp)
if err == nil {
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
}
}
t.Run("Monolith", func(t *testing.T) {
testCases(t, intAPI)
// also test tracing
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
})
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, intAPI, false)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
if err != nil {
t.Fatalf("failed to create HTTP client: %s", err)
}
testCases(t, userHTTPApi)
})
})
}

119
userapi/util/notify_test.go Normal file
View file

@ -0,0 +1,119 @@
package util_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
userUtil "github.com/matrix-org/dendrite/userapi/util"
)
func TestNotifyUserCountsAsync(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
if err != nil {
t.Error(err)
}
ctx := context.Background()
// Create a test room, just used to provide events
room := test.NewRoom(t, alice)
dummyEvent := room.Events()[len(room.Events())-1]
appID := util.RandomString(8)
pushKey := util.RandomString(8)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
receivedRequest := make(chan bool, 1)
// create a test server which responds to our /notify call
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var data pushgateway.NotifyRequest
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
t.Error(err)
}
notification := data.Notification
// Validate the request
if notification.Counts == nil {
t.Fatal("no unread notification counts in request")
}
if unread := notification.Counts.Unread; unread != 1 {
t.Errorf("expected one unread notification, got %d", unread)
}
if len(notification.Devices) == 0 {
t.Fatal("expected devices in request")
}
// We only created one push device, so access it directly
device := notification.Devices[0]
if device.AppID != appID {
t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID)
}
if device.PushKey != pushKey {
t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey)
}
// Return empty result, otherwise the call is handled as failed
if _, err := w.Write([]byte("{}")); err != nil {
t.Error(err)
}
close(receivedRequest)
}))
defer srv.Close()
// Create DB and Dendrite base
connStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
base, _, _ := testrig.Base(nil)
defer base.Close()
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "test", bcrypt.MinCost, 0, 0, "")
if err != nil {
t.Error(err)
}
// Prepare pusher with our test server URL
if err := db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind,
AppID: appID,
PushKey: pushKey,
Data: map[string]interface{}{
"url": srv.URL,
},
}, aliceLocalpart, serverName); err != nil {
t.Error(err)
}
// Insert a dummy event
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll),
}); err != nil {
t.Error(err)
}
// Notify the user about a new notification
if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil {
t.Error(err)
}
select {
case <-time.After(time.Second * 5):
t.Error("timed out waiting for response")
case <-receivedRequest:
}
})
}

Some files were not shown because too many files have changed in this diff Show more