mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-17 02:53:11 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/complementcoverage
This commit is contained in:
commit
dbbbd2985d
3
.github/workflows/dendrite.yml
vendored
3
.github/workflows/dendrite.yml
vendored
|
|
@ -331,7 +331,8 @@ jobs:
|
|||
postgres: postgres
|
||||
api: full-http
|
||||
container:
|
||||
image: matrixdotorg/sytest-dendrite:latest
|
||||
# Temporary for debugging to see if this image is working better.
|
||||
image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73
|
||||
volumes:
|
||||
- ${{ github.workspace }}:/src
|
||||
- /root/.cache/go-build:/github/home/.cache/go-build
|
||||
|
|
|
|||
|
|
@ -180,14 +180,14 @@ func startup() {
|
|||
base := base.NewBaseDendrite(cfg, "Monolith")
|
||||
defer base.Close() // nolint: errcheck
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
|
||||
federation := conn.CreateFederationClient(base, pSessions)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
|
||||
|
||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
|
|
|
|||
|
|
@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() {
|
|||
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(m.userAPI)
|
||||
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() {
|
|||
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
|
|
|
|||
134
clientapi/admin_test.go
Normal file
134
clientapi/admin_test.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
package clientapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/keyserver"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func TestAdminResetPassword(t *testing.T) {
|
||||
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||
vhUser := &test.User{ID: "@vhuser:vh1"}
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer baseClose()
|
||||
|
||||
// add a vhost
|
||||
base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||
SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"},
|
||||
})
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
// Needed for changing the password/login
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||
AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil)
|
||||
|
||||
// Create the users in the userapi and login
|
||||
accessTokens := map[*test.User]string{
|
||||
aliceAdmin: "",
|
||||
bob: "",
|
||||
vhUser: "",
|
||||
}
|
||||
for u := range accessTokens {
|
||||
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
password := util.RandomString(8)
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
Password: password,
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
|
||||
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
|
||||
"type": authtypes.LoginTypePassword,
|
||||
"identifier": map[string]interface{}{
|
||||
"type": "m.id.user",
|
||||
"user": u.ID,
|
||||
},
|
||||
"password": password,
|
||||
}))
|
||||
rec := httptest.NewRecorder()
|
||||
base.PublicClientAPIMux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("failed to login: %s", rec.Body.String())
|
||||
}
|
||||
accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String()
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
requestingUser *test.User
|
||||
userID string
|
||||
requestOpt test.HTTPRequestOpt
|
||||
wantOK bool
|
||||
withHeader bool
|
||||
}{
|
||||
{name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID},
|
||||
{name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID},
|
||||
{name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||
"password": util.RandomString(8),
|
||||
})},
|
||||
{name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s
|
||||
{name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||
"password": "",
|
||||
})},
|
||||
{name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||
{name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||
{name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||
"password": util.RandomString(8),
|
||||
})},
|
||||
{name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID},
|
||||
{name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||
{name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)},
|
||||
{name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||
"password": util.RandomString(6),
|
||||
})},
|
||||
{name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||
"password": util.RandomString(513),
|
||||
})},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID)
|
||||
if tc.requestOpt != nil {
|
||||
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt)
|
||||
}
|
||||
|
||||
if tc.withHeader {
|
||||
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser])
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
base.DendriteAdminMux.ServeHTTP(rec, req)
|
||||
t.Logf("%s", rec.Body.String())
|
||||
if tc.wantOK && rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -15,6 +15,8 @@
|
|||
package clientapi
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||
|
|
@ -26,7 +28,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
||||
|
|
@ -57,10 +58,7 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
routing.Setup(
|
||||
base.PublicClientAPIMux,
|
||||
base.PublicWellKnownAPIMux,
|
||||
base.SynapseAdminMux,
|
||||
base.DendriteAdminMux,
|
||||
base,
|
||||
cfg, rsAPI, asAPI,
|
||||
userAPI, userDirectoryProvider, federation,
|
||||
syncProducer, transactionsCache, fsAPI, keyAPI,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/nats-io/nats.go"
|
||||
|
|
@ -98,20 +99,40 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
|
|||
}
|
||||
|
||||
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||
if req.Body == nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.Unknown("Missing request body"),
|
||||
}
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
serverName := cfg.Matrix.ServerName
|
||||
localpart, ok := vars["localpart"]
|
||||
if !ok {
|
||||
var localpart string
|
||||
userID := vars["userID"]
|
||||
localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
||||
JSON: jsonerror.InvalidArgumentValue(err.Error()),
|
||||
}
|
||||
}
|
||||
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
|
||||
localpart, serverName = l, s
|
||||
accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
|
||||
if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
}, accAvailableResp); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.InternalAPIError(req.Context(), err),
|
||||
}
|
||||
}
|
||||
if accAvailableResp.Available {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: jsonerror.Unknown("User does not exist"),
|
||||
}
|
||||
}
|
||||
request := struct {
|
||||
Password string `json:"password"`
|
||||
|
|
@ -128,6 +149,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
|
||||
}
|
||||
}
|
||||
|
||||
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
|
|||
joinReq := roomserverAPI.PerformJoinRequest{
|
||||
RoomIDOrAlias: roomIDOrAlias,
|
||||
UserID: device.UserID,
|
||||
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||
Content: map[string]interface{}{},
|
||||
}
|
||||
joinRes := roomserverAPI.PerformJoinResponse{}
|
||||
|
|
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
|
|||
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
||||
done <- jsonerror.InternalAPIError(req.Context(), err)
|
||||
} else if joinRes.Error != nil {
|
||||
done <- joinRes.Error.JSONResponse()
|
||||
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
|
||||
done <- util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
|
||||
}
|
||||
} else {
|
||||
done <- joinRes.Error.JSONResponse()
|
||||
}
|
||||
} else {
|
||||
done <- util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
|
|
|
|||
158
clientapi/routing/joinroom_test.go
Normal file
158
clientapi/routing/joinroom_test.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice"
|
||||
"github.com/matrix-org/dendrite/keyserver"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer baseClose()
|
||||
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
|
||||
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
|
||||
|
||||
// Create the users in the userapi
|
||||
for _, u := range []*test.User{alice, bob, charlie} {
|
||||
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||
userRes := &uapi.PerformAccountCreationResponse{}
|
||||
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
Password: "someRandomPassword",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
aliceDev := &uapi.Device{UserID: alice.ID}
|
||||
bobDev := &uapi.Device{UserID: bob.ID}
|
||||
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
|
||||
|
||||
// create a room with disabled guest access and invite Bob
|
||||
resp := createRoom(ctx, createRoomRequest{
|
||||
Name: "testing",
|
||||
IsDirect: true,
|
||||
Topic: "testing",
|
||||
Visibility: "public",
|
||||
Preset: presetPublicChat,
|
||||
RoomAliasName: "alias",
|
||||
Invite: []string{bob.ID},
|
||||
GuestCanJoin: false,
|
||||
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||
crResp, ok := resp.JSON.(createRoomResponse)
|
||||
if !ok {
|
||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||
}
|
||||
|
||||
// create a room with guest access enabled and invite Charlie
|
||||
resp = createRoom(ctx, createRoomRequest{
|
||||
Name: "testing",
|
||||
IsDirect: true,
|
||||
Topic: "testing",
|
||||
Visibility: "public",
|
||||
Preset: presetPublicChat,
|
||||
Invite: []string{charlie.ID},
|
||||
GuestCanJoin: true,
|
||||
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||
if !ok {
|
||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||
}
|
||||
|
||||
// Dummy request
|
||||
body := &bytes.Buffer{}
|
||||
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
device *uapi.Device
|
||||
roomID string
|
||||
wantHTTP200 bool
|
||||
}{
|
||||
{
|
||||
name: "User can join successfully by alias",
|
||||
device: bobDev,
|
||||
roomID: crResp.RoomAlias,
|
||||
wantHTTP200: true,
|
||||
},
|
||||
{
|
||||
name: "User can join successfully by roomID",
|
||||
device: bobDev,
|
||||
roomID: crResp.RoomID,
|
||||
wantHTTP200: true,
|
||||
},
|
||||
{
|
||||
name: "join is forbidden if user is guest",
|
||||
device: charlieDev,
|
||||
roomID: crResp.RoomID,
|
||||
},
|
||||
{
|
||||
name: "room does not exist",
|
||||
device: aliceDev,
|
||||
roomID: "!doesnotexist:test",
|
||||
},
|
||||
{
|
||||
name: "user from different server",
|
||||
device: &uapi.Device{UserID: "@wrong:server"},
|
||||
roomID: crResp.RoomAlias,
|
||||
},
|
||||
{
|
||||
name: "user doesn't exist locally",
|
||||
device: &uapi.Device{UserID: "@doesnotexist:test"},
|
||||
roomID: crResp.RoomAlias,
|
||||
},
|
||||
{
|
||||
name: "invalid room ID",
|
||||
device: aliceDev,
|
||||
roomID: "invalidRoomID",
|
||||
},
|
||||
{
|
||||
name: "roomAlias does not exist",
|
||||
device: aliceDev,
|
||||
roomID: "#doesnotexist:test",
|
||||
},
|
||||
{
|
||||
name: "room with guest_access event",
|
||||
device: charlieDev,
|
||||
roomID: crRespWithGuestAccess.RoomID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
|
||||
if tc.wantHTTP200 && !joinResp.Is2xx() {
|
||||
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -81,7 +82,7 @@ func Password(
|
|||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
// Check the new password strength.
|
||||
if resErr = validatePassword(r.NewPassword); resErr != nil {
|
||||
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
|
@ -60,8 +61,6 @@ var (
|
|||
)
|
||||
|
||||
const (
|
||||
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||
sessionIDLength = 24
|
||||
)
|
||||
|
|
@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl
|
|||
return nil
|
||||
}
|
||||
|
||||
// validatePassword returns an error response if the password is invalid
|
||||
func validatePassword(password string) *util.JSONResponse {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if len(password) > maxPasswordLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
|
||||
}
|
||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||
func validateRecaptcha(
|
||||
cfg *config.ClientAPI,
|
||||
|
|
@ -636,7 +618,7 @@ func Register(
|
|||
return *resErr
|
||||
}
|
||||
}
|
||||
if resErr := validatePassword(r.Password); resErr != nil {
|
||||
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
|
|
@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
|||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
if resErr := validatePassword(ssrr.Password); resErr != nil {
|
||||
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
deviceID := "shared_secret_registration"
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/nats-io/nats.go"
|
||||
|
|
@ -49,7 +50,7 @@ import (
|
|||
// applied:
|
||||
// nolint: gocyclo
|
||||
func Setup(
|
||||
publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
|
||||
base *base.BaseDendrite,
|
||||
cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
|
|
@ -63,7 +64,14 @@ func Setup(
|
|||
extRoomsProvider api.ExtraPublicRoomsProvider,
|
||||
mscCfg *config.MSCs, natsClient *nats.Conn,
|
||||
) {
|
||||
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
||||
publicAPIMux := base.PublicClientAPIMux
|
||||
wkMux := base.PublicWellKnownAPIMux
|
||||
synapseAdminRouter := base.SynapseAdminMux
|
||||
dendriteAdminRouter := base.DendriteAdminMux
|
||||
|
||||
if base.EnableMetrics {
|
||||
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
||||
}
|
||||
|
||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
|
||||
|
|
@ -631,7 +639,7 @@ func Setup(
|
|||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||
vars := mux.Vars(req)
|
||||
return AuthFallback(w, req, vars["authType"], cfg)
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ func main() {
|
|||
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent)
|
||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
keyAPI.SetUserAPI(userAPI)
|
||||
|
||||
|
|
|
|||
|
|
@ -157,11 +157,12 @@ func main() {
|
|||
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||
keyRing := serverKeyAPI.KeyRing()
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
||||
|
||||
rsComponent := roomserver.NewInternalAPI(
|
||||
base,
|
||||
)
|
||||
|
||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent)
|
||||
|
||||
rsAPI := rsComponent
|
||||
|
||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ func main() {
|
|||
}
|
||||
keyRing := fsAPI.KeyRing()
|
||||
|
||||
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||
keyAPI := keyImpl
|
||||
if base.UseHTTPAPIs {
|
||||
keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ import (
|
|||
|
||||
func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
||||
fsAPI := base.FederationAPIHTTPClient()
|
||||
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
||||
rsAPI := base.RoomserverHTTPClient()
|
||||
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||
intAPI.SetUserAPI(base.UserAPIClient())
|
||||
|
||||
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics)
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ Please use PostgreSQL wherever possible, especially if you are planning to run a
|
|||
## Dendrite is using a lot of CPU
|
||||
|
||||
Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the
|
||||
CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would
|
||||
CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would
|
||||
be a huge help too, as that will help us to understand where the CPU time is going.
|
||||
|
||||
## Dendrite is using a lot of RAM
|
||||
|
|
@ -99,7 +99,7 @@ be a huge help too, as that will help us to understand where the CPU time is goi
|
|||
As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work
|
||||
at a given instant. However, if it is using more RAM than you expect for a long time, that's probably
|
||||
not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage
|
||||
ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that
|
||||
ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that
|
||||
would be a huge help too, as that will help us to understand where the memory usage is happening.
|
||||
|
||||
## Dendrite is running out of PostgreSQL database connections
|
||||
|
|
|
|||
|
|
@ -231,9 +231,9 @@ GEM
|
|||
jekyll-seo-tag (~> 2.1)
|
||||
minitest (5.15.0)
|
||||
multipart-post (2.1.1)
|
||||
nokogiri (1.13.9-arm64-darwin)
|
||||
nokogiri (1.13.10-arm64-darwin)
|
||||
racc (~> 1.4)
|
||||
nokogiri (1.13.9-x86_64-linux)
|
||||
nokogiri (1.13.10-x86_64-linux)
|
||||
racc (~> 1.4)
|
||||
octokit (4.22.0)
|
||||
faraday (>= 0.9)
|
||||
|
|
@ -241,7 +241,7 @@ GEM
|
|||
pathutil (0.16.2)
|
||||
forwardable-extended (~> 2.6)
|
||||
public_suffix (4.0.7)
|
||||
racc (1.6.0)
|
||||
racc (1.6.1)
|
||||
rb-fsevent (0.11.1)
|
||||
rb-inotify (0.10.1)
|
||||
ffi (~> 1.0)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL
|
|||
all rooms which they are currently joined. A JSON body will be returned containing
|
||||
the room IDs of all affected rooms.
|
||||
|
||||
## POST `/_dendrite/admin/resetPassword/{localpart}`
|
||||
## POST `/_dendrite/admin/resetPassword/{userID}`
|
||||
|
||||
Reset the password of a local user.
|
||||
|
||||
Request body format:
|
||||
|
||||
|
|
@ -54,9 +56,6 @@ Request body format:
|
|||
}
|
||||
```
|
||||
|
||||
Reset the password of a local user. The `localpart` is the username only, i.e. if
|
||||
the full user ID is `@alice:domain.com` then the local part is `alice`.
|
||||
|
||||
## GET `/_dendrite/admin/fulltext/reindex`
|
||||
|
||||
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,28 @@ permalink: /development/contributing
|
|||
Everyone is welcome to contribute to Dendrite! We aim to make it as easy as
|
||||
possible to get started.
|
||||
|
||||
## Contribution types
|
||||
|
||||
We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it
|
||||
is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept:
|
||||
- bug fixes
|
||||
- security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests)
|
||||
|
||||
We will accept the following with caveats:
|
||||
- documentation fixes, provided they do not add additional instructions which can end up going out-of-date,
|
||||
e.g example configs, shell commands.
|
||||
- performance fixes, provided they do not add significantly more maintenance burden.
|
||||
- additional functionality on existing features, provided the functionality is small and maintainable.
|
||||
- additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations
|
||||
- test-only changes, provided they help improve coverage or test tricky code.
|
||||
|
||||
The following items are at risk of not being accepted:
|
||||
- Configuration or CLI changes, particularly ones which increase the overall configuration surface.
|
||||
|
||||
The following items are unlikely to be accepted into a main Dendrite release for now:
|
||||
- New MSC implementations.
|
||||
- New features which are not in the specification.
|
||||
|
||||
## Sign off
|
||||
|
||||
We require that everyone who contributes to the project signs off their contributions
|
||||
|
|
@ -35,7 +57,7 @@ to do so for future contributions.
|
|||
|
||||
## Getting up and running
|
||||
|
||||
See the [Installation](installation) section for information on how to build an
|
||||
See the [Installation](../installation) section for information on how to build an
|
||||
instance of Dendrite. You will likely need this in order to test your changes.
|
||||
|
||||
## Code style
|
||||
|
|
@ -129,7 +151,7 @@ significant amount of CPU and RAM.
|
|||
|
||||
Once the code builds, run [Sytest](https://github.com/matrix-org/sytest)
|
||||
according to the guide in
|
||||
[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image)
|
||||
[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image)
|
||||
so you can see whether something is being broken and whether there are newly
|
||||
passing tests.
|
||||
|
||||
|
|
@ -232,7 +232,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
|
|||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) {
|
||||
joined := make([]gomatrixserverlib.ServerName, len(addedJoined))
|
||||
joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined))
|
||||
for _, added := range addedJoined {
|
||||
joined = append(joined, added.ServerName)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
||||
d.dbMutex.Lock()
|
||||
defer d.dbMutex.Unlock()
|
||||
|
||||
var count int64
|
||||
if pdus, ok := d.associatedPDUs[serverName]; ok {
|
||||
count = int64(len(pdus))
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
||||
d.dbMutex.Lock()
|
||||
defer d.dbMutex.Unlock()
|
||||
|
||||
var count int64
|
||||
if edus, ok := d.associatedEDUs[serverName]; ok {
|
||||
count = int64(len(edus))
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
||||
d.dbMutex.Lock()
|
||||
defer d.dbMutex.Unlock()
|
||||
|
|
|
|||
|
|
@ -45,9 +45,6 @@ type Database interface {
|
|||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
|
||||
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
|
||||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
|||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectInboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts"
|
||||
|
||||
const renewInboundPeekSQL = "" +
|
||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteInboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const deleteInboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
|
@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er
|
|||
return
|
||||
}
|
||||
|
||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||
|
|
|
|||
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
|||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectOutboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||
|
||||
const renewOutboundPeekSQL = "" +
|
||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteOutboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const deleteOutboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
|
@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err
|
|||
return
|
||||
}
|
||||
|
||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||
|
|
|
|||
|
|
@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
|||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueueEDUCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueueServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||
|
||||
|
|
@ -81,7 +77,6 @@ type queueEDUsStatements struct {
|
|||
deleteQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueueEDUCountStmt *sql.Stmt
|
||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||
selectExpiredEDUsStmt *sql.Stmt
|
||||
deleteExpiredEDUsStmt *sql.Stmt
|
||||
|
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
|||
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
|
||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||
|
|
@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
|||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
|
|
|
|||
|
|
@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" +
|
|||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueuePDUsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueuePDUServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||
|
||||
|
|
@ -71,7 +67,6 @@ type queuePDUsStatements struct {
|
|||
deleteQueuePDUsStmt *sql.Stmt
|
||||
selectQueuePDUsStmt *sql.Stmt
|
||||
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueuePDUsCountStmt *sql.Stmt
|
||||
selectQueuePDUServerNamesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
|
|
@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
|||
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
|||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
|
|
|
|||
|
|
@ -162,15 +162,6 @@ func (d *Database) CleanEDUs(
|
|||
})
|
||||
}
|
||||
|
||||
// GetPendingEDUCount returns the number of EDUs waiting to be
|
||||
// sent for a given servername.
|
||||
func (d *Database) GetPendingEDUCount(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
// GetPendingServerNames returns the server names that have EDUs
|
||||
// waiting to be sent.
|
||||
func (d *Database) GetPendingEDUServerNames(
|
||||
|
|
|
|||
|
|
@ -141,15 +141,6 @@ func (d *Database) CleanPDUs(
|
|||
})
|
||||
}
|
||||
|
||||
// GetPendingPDUCount returns the number of PDUs waiting to be
|
||||
// sent for a given servername.
|
||||
func (d *Database) GetPendingPDUCount(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
// GetPendingServerNames returns the server names that have PDUs
|
||||
// waiting to be sent.
|
||||
func (d *Database) GetPendingPDUServerNames(
|
||||
|
|
|
|||
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
|||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectInboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||
|
||||
const renewInboundPeekSQL = "" +
|
||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteInboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const deleteInboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
|
@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro
|
|||
return
|
||||
}
|
||||
|
||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||
|
|
|
|||
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
|||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectOutboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||
|
||||
const renewOutboundPeekSQL = "" +
|
||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteOutboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const deleteOutboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
|
@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er
|
|||
return
|
||||
}
|
||||
|
||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||
|
|
|
|||
|
|
@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
|||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueueEDUCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueueServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||
|
||||
|
|
@ -82,7 +78,6 @@ type queueEDUsStatements struct {
|
|||
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
selectQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueueEDUCountStmt *sql.Stmt
|
||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||
selectExpiredEDUsStmt *sql.Stmt
|
||||
deleteExpiredEDUsStmt *sql.Stmt
|
||||
|
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
|||
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
|
||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||
|
|
@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
|||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
|
|
|
|||
|
|
@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" +
|
|||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueuePDUsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueuePDUsServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||
|
||||
|
|
@ -79,7 +75,6 @@ type queuePDUsStatements struct {
|
|||
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||
selectQueuePDUsStmt *sql.Stmt
|
||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueuePDUsCountStmt *sql.Stmt
|
||||
selectQueueServerNamesStmt *sql.Stmt
|
||||
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
|
@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
|||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
|||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ package storage_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
|
|
@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) {
|
|||
assert.Equal(t, 2, len(data))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOutboundPeeking(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
peekID := util.RandomString(8)
|
||||
var renewalInterval int64 = 1000
|
||||
|
||||
// Add outbound peek
|
||||
if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// select the newly inserted peek
|
||||
outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Assert fields are set as expected
|
||||
if outboundPeek1.PeekID != peekID {
|
||||
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||
}
|
||||
if outboundPeek1.RoomID != room.ID {
|
||||
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||
}
|
||||
if outboundPeek1.ServerName != serverName {
|
||||
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||
}
|
||||
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||
}
|
||||
// Renew the peek
|
||||
if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// verify the values changed
|
||||
outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||
t.Fatal("expected a change peek, but they are the same")
|
||||
}
|
||||
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||
}
|
||||
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||
}
|
||||
|
||||
// insert some peeks
|
||||
peekIDs := []string{peekID}
|
||||
for i := 0; i < 5; i++ {
|
||||
peekID = util.RandomString(8)
|
||||
if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peekIDs = append(peekIDs, peekID)
|
||||
}
|
||||
|
||||
// Now select them
|
||||
outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(outboundPeeks) != len(peekIDs) {
|
||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||
}
|
||||
gotPeekIDs := make([]string, 0, len(outboundPeeks))
|
||||
for _, p := range outboundPeeks {
|
||||
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||
}
|
||||
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInboundPeeking(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
peekID := util.RandomString(8)
|
||||
var renewalInterval int64 = 1000
|
||||
|
||||
// Add inbound peek
|
||||
if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// select the newly inserted peek
|
||||
inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Assert fields are set as expected
|
||||
if inboundPeek1.PeekID != peekID {
|
||||
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||
}
|
||||
if inboundPeek1.RoomID != room.ID {
|
||||
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||
}
|
||||
if inboundPeek1.ServerName != serverName {
|
||||
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||
}
|
||||
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||
}
|
||||
// Renew the peek
|
||||
if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// verify the values changed
|
||||
inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||
t.Fatal("expected a change peek, but they are the same")
|
||||
}
|
||||
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||
}
|
||||
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||
}
|
||||
|
||||
// insert some peeks
|
||||
peekIDs := []string{peekID}
|
||||
for i := 0; i < 5; i++ {
|
||||
peekID = util.RandomString(8)
|
||||
if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peekIDs = append(peekIDs, peekID)
|
||||
}
|
||||
|
||||
// Now select them
|
||||
inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(inboundPeeks) != len(peekIDs) {
|
||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||
}
|
||||
gotPeekIDs := make([]string, 0, len(inboundPeeks))
|
||||
for _, p := range inboundPeeks {
|
||||
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||
}
|
||||
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
149
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
149
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %s", err)
|
||||
}
|
||||
var tab tables.FederationInboundPeeks
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresInboundPeeksTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSQLiteInboundPeeksTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestInboundPeeksTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustCreateInboundpeeksTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
// Insert a peek
|
||||
peekID := util.RandomString(8)
|
||||
var renewalInterval int64 = 1000
|
||||
if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// select the newly inserted peek
|
||||
inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Assert fields are set as expected
|
||||
if inboundPeek1.PeekID != peekID {
|
||||
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||
}
|
||||
if inboundPeek1.RoomID != room.ID {
|
||||
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||
}
|
||||
if inboundPeek1.ServerName != serverName {
|
||||
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||
}
|
||||
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||
}
|
||||
|
||||
// Renew the peek
|
||||
if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// verify the values changed
|
||||
inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||
t.Fatal("expected a change peek, but they are the same")
|
||||
}
|
||||
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||
}
|
||||
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||
}
|
||||
|
||||
// delete the peek
|
||||
if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// There should be no peek anymore
|
||||
peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if peek != nil {
|
||||
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||
}
|
||||
|
||||
// insert some peeks
|
||||
var peekIDs []string
|
||||
for i := 0; i < 5; i++ {
|
||||
peekID = util.RandomString(8)
|
||||
if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peekIDs = append(peekIDs, peekID)
|
||||
}
|
||||
|
||||
// Now select them
|
||||
inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(inboundPeeks) != len(peekIDs) {
|
||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||
}
|
||||
gotPeekIDs := make([]string, 0, len(inboundPeeks))
|
||||
for _, p := range inboundPeeks {
|
||||
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||
}
|
||||
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||
|
||||
// And delete them again
|
||||
if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// they should be gone now
|
||||
inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(inboundPeeks) > 0 {
|
||||
t.Fatal("got inbound peeks which should be deleted")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
|
@ -28,7 +28,6 @@ type FederationQueuePDUs interface {
|
|||
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||
}
|
||||
|
|
@ -38,7 +37,6 @@ type FederationQueueEDUs interface {
|
|||
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
|
||||
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
|
||||
|
|
|
|||
148
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
148
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %s", err)
|
||||
}
|
||||
var tab tables.FederationOutboundPeeks
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresOutboundPeeksTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestOutboundPeeksTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustCreateOutboundpeeksTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
// Insert a peek
|
||||
peekID := util.RandomString(8)
|
||||
var renewalInterval int64 = 1000
|
||||
if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// select the newly inserted peek
|
||||
outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Assert fields are set as expected
|
||||
if outboundPeek1.PeekID != peekID {
|
||||
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||
}
|
||||
if outboundPeek1.RoomID != room.ID {
|
||||
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||
}
|
||||
if outboundPeek1.ServerName != serverName {
|
||||
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||
}
|
||||
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||
}
|
||||
|
||||
// Renew the peek
|
||||
if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// verify the values changed
|
||||
outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||
t.Fatal("expected a change peek, but they are the same")
|
||||
}
|
||||
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||
}
|
||||
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||
}
|
||||
|
||||
// delete the peek
|
||||
if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// There should be no peek anymore
|
||||
peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if peek != nil {
|
||||
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||
}
|
||||
|
||||
// insert some peeks
|
||||
var peekIDs []string
|
||||
for i := 0; i < 5; i++ {
|
||||
peekID = util.RandomString(8)
|
||||
if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peekIDs = append(peekIDs, peekID)
|
||||
}
|
||||
|
||||
// Now select them
|
||||
outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(outboundPeeks) != len(peekIDs) {
|
||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||
}
|
||||
gotPeekIDs := make([]string, 0, len(outboundPeeks))
|
||||
for _, p := range outboundPeeks {
|
||||
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||
}
|
||||
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||
|
||||
// And delete them again
|
||||
if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// they should be gone now
|
||||
outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(outboundPeeks) > 0 {
|
||||
t.Fatal("got outbound peeks which should be deleted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
|||
|
||||
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
||||
// This is used to serve HTML alongside JSON error messages
|
||||
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
||||
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||
span := opentracing.StartSpan(metricsName)
|
||||
defer span.Finish()
|
||||
|
|
@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
|
|||
}
|
||||
}
|
||||
|
||||
if !enableMetrics {
|
||||
return http.HandlerFunc(withSpan)
|
||||
}
|
||||
|
||||
return promhttp.InstrumentHandlerCounter(
|
||||
promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
|
|
|
|||
|
|
@ -33,6 +33,11 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// logrus is using a global variable when we're using `logrus.AddHook`
|
||||
// this unfortunately results in us adding the same hook multiple times.
|
||||
// This map ensures we only ever add one level hook.
|
||||
var stdLevelLogAdded = make(map[logrus.Level]bool)
|
||||
|
||||
type utcFormatter struct {
|
||||
logrus.Formatter
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,16 +22,16 @@ import (
|
|||
"log/syslog"
|
||||
|
||||
"github.com/MFAshby/stdemuxerhook"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/sirupsen/logrus"
|
||||
lSyslog "github.com/sirupsen/logrus/hooks/syslog"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// SetupHookLogging configures the logging hooks defined in the configuration.
|
||||
// If something fails here it means that the logging was improperly configured,
|
||||
// so we just exit with the error
|
||||
func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
|
||||
stdLogAdded := false
|
||||
for _, hook := range hooks {
|
||||
// Check we received a proper logging level
|
||||
level, err := logrus.ParseLevel(hook.Level)
|
||||
|
|
@ -54,14 +54,11 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
|
|||
setupSyslogHook(hook, level, componentName)
|
||||
case "std":
|
||||
setupStdLogHook(level)
|
||||
stdLogAdded = true
|
||||
default:
|
||||
logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type)
|
||||
}
|
||||
}
|
||||
if !stdLogAdded {
|
||||
setupStdLogHook(logrus.InfoLevel)
|
||||
}
|
||||
setupStdLogHook(logrus.InfoLevel)
|
||||
// Hooks are now configured for stdout/err, so throw away the default logger output
|
||||
logrus.SetOutput(io.Discard)
|
||||
}
|
||||
|
|
@ -88,7 +85,11 @@ func checkSyslogHookParams(params map[string]interface{}) {
|
|||
}
|
||||
|
||||
func setupStdLogHook(level logrus.Level) {
|
||||
if stdLevelLogAdded[level] {
|
||||
return
|
||||
}
|
||||
logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())})
|
||||
stdLevelLogAdded[level] = true
|
||||
}
|
||||
|
||||
func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) {
|
||||
|
|
|
|||
44
internal/validate.go
Normal file
44
internal/validate.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||
|
||||
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
|
||||
// ValidatePassword returns an error response if the password is invalid
|
||||
func ValidatePassword(password string) *util.JSONResponse {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if len(password) > maxPasswordLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
|
||||
}
|
||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -24,6 +24,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -102,6 +104,7 @@ type DeviceListUpdater struct {
|
|||
// block on or timeout via a select.
|
||||
userIDToChan map[string]chan bool
|
||||
userIDToChanMu *sync.Mutex
|
||||
rsAPI rsapi.KeyserverRoomserverAPI
|
||||
}
|
||||
|
||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||
|
|
@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
|
|||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
|
||||
}
|
||||
|
||||
type DeviceListUpdaterAPI interface {
|
||||
|
|
@ -140,7 +145,7 @@ func NewDeviceListUpdater(
|
|||
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
||||
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||
thisServer gomatrixserverlib.ServerName,
|
||||
rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
process: process,
|
||||
|
|
@ -154,6 +159,7 @@ func NewDeviceListUpdater(
|
|||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||
userIDToChan: make(map[string]chan bool),
|
||||
userIDToChanMu: &sync.Mutex{},
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error {
|
|||
go u.worker(ch)
|
||||
}
|
||||
|
||||
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
|
||||
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// CleanUp removes stale device entries for users we don't share a room with anymore
|
||||
func (u *DeviceListUpdater) CleanUp() error {
|
||||
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := rsapi.QueryLeftUsersResponse{}
|
||||
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(res.LeftUsers) == 0 {
|
||||
return nil
|
||||
}
|
||||
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
|
||||
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,12 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct {
|
|||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
|
|
@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Foo Bar",
|
||||
Deleted: false,
|
||||
|
|
@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
`)),
|
||||
}, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
|
@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
UserID: remoteUserID,
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
|
|
@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) {
|
|||
close(incomingFedReq)
|
||||
return <-fedCh, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
|
@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) {
|
|||
t.Errorf("user %s is marked as stale", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
t.Helper()
|
||||
|
||||
base, _, _ := testrig.Base(nil)
|
||||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return db, clearDB
|
||||
}
|
||||
|
||||
type mockKeyserverRoomserverAPI struct {
|
||||
leftUsers []string
|
||||
}
|
||||
|
||||
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||
res.LeftUsers = m.leftUsers
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDeviceListUpdater_CleanUp(t *testing.T) {
|
||||
processCtx := process.NewProcessContext()
|
||||
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
|
||||
// Bob is not joined to any of our rooms
|
||||
rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clearDB := mustCreateKeyserverDB(t, dbType)
|
||||
defer clearDB()
|
||||
|
||||
// This should not get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// this one should get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updater := NewDeviceListUpdater(processCtx, db, nil,
|
||||
nil, nil,
|
||||
0, rsAPI, "test")
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// check that we still have Alice in our stale list
|
||||
staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// There should only be Alice
|
||||
wantCount := 1
|
||||
if count := len(staleUsers); count != wantCount {
|
||||
t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
|
||||
}
|
||||
|
||||
if staleUsers[0] != alice.ID {
|
||||
t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/consumers"
|
||||
|
|
@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr
|
|||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
|
||||
rsAPI rsapi.KeyserverRoomserverAPI,
|
||||
) api.KeyInternalAPI {
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
|
|
@ -47,6 +50,7 @@ func NewInternalAPI(
|
|||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to key server database")
|
||||
}
|
||||
|
||||
keyChangeProducer := &producers.KeyChange{
|
||||
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
|
||||
JetStream: js,
|
||||
|
|
@ -58,8 +62,14 @@ func NewInternalAPI(
|
|||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
}
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||
ap.Updater = updater
|
||||
|
||||
// Remove users which we don't share a room with anymore
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
logrus.WithError(err).Error("failed to cleanup stale device lists")
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := updater.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start device list updater")
|
||||
|
|
|
|||
29
keyserver/keyserver_test.go
Normal file
29
keyserver/keyserver_test.go
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
package keyserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
)
|
||||
|
||||
type mockKeyserverRoomserverAPI struct {
|
||||
leftUsers []string
|
||||
}
|
||||
|
||||
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||
res.LeftUsers = m.leftUsers
|
||||
return nil
|
||||
}
|
||||
|
||||
// Merely tests that we can create an internal keyserver API
|
||||
func Test_NewInternalAPI(t *testing.T) {
|
||||
rsAPI := &mockKeyserverRoomserverAPI{}
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer closeBase()
|
||||
_ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||
})
|
||||
}
|
||||
|
|
@ -85,4 +85,9 @@ type Database interface {
|
|||
|
||||
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
|
||||
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||
|
||||
DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
|
|||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
deleteStaleDeviceListsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
|
|
@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
|
|
@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
|
|
|
|||
|
|
@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
|
|||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
|
||||
func (d *Database) DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,11 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
|
|||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
db *sql.DB
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
|
||||
}
|
||||
|
||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
|
|
@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
|
|
@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
|
||||
stmt, err := s.db.Prepare(qry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
|
||||
params := make([]any, len(userIDs))
|
||||
for i := range userIDs {
|
||||
params[i] = userIDs[i]
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ type KeyChanges interface {
|
|||
type StaleDeviceLists interface {
|
||||
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
||||
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
|
||||
}
|
||||
|
||||
type CrossSigningKeys interface {
|
||||
|
|
|
|||
94
keyserver/storage/tables/stale_device_lists_test.go
Normal file
94
keyserver/storage/tables/stale_device_lists_test.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %s", err)
|
||||
}
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStaleDeviceLists(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := "@charlie:localhost"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustCreateTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
|
||||
// Query one server
|
||||
wantStaleUsers := []string{alice.ID, bob.ID}
|
||||
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Query all servers
|
||||
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Delete stale devices
|
||||
deleteUsers := []string{alice.ID, bob.ID}
|
||||
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
|
||||
t.Fatalf("failed to delete stale device lists: %s", err)
|
||||
}
|
||||
|
||||
// Verify we don't get anything back after deleting
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
|
||||
if gotCount := len(gotStaleUsers); gotCount > 0 {
|
||||
t.Fatalf("expected no stale users, got %d", gotCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
|
|||
ClientRoomserverAPI
|
||||
UserRoomserverAPI
|
||||
FederationRoomserverAPI
|
||||
KeyserverRoomserverAPI
|
||||
|
||||
// needed to avoid chicken and egg scenario when setting up the
|
||||
// interdependencies between the roomserver and other input APIs
|
||||
|
|
@ -199,3 +200,7 @@ type FederationRoomserverAPI interface {
|
|||
// Query a given amount (or less) of events prior to a given set of events.
|
||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||
}
|
||||
|
||||
type KeyserverRoomserverAPI interface {
|
||||
QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct {
|
|||
Impl RoomserverInternalAPI
|
||||
}
|
||||
|
||||
func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
|
||||
err := t.Impl.QueryLeftUsers(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
|
||||
t.Impl.SetFederationAPI(fsAPI, keyRing)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ const (
|
|||
type PerformJoinRequest struct {
|
||||
RoomIDOrAlias string `json:"room_id_or_alias"`
|
||||
UserID string `json:"user_id"`
|
||||
IsGuest bool `json:"is_guest"`
|
||||
Content map[string]interface{} `json:"content"`
|
||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||
Unsigned map[string]interface{} `json:"unsigned"`
|
||||
|
|
|
|||
|
|
@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct {
|
|||
// do not have known state will return an empty array here.
|
||||
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
|
||||
}
|
||||
|
||||
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
|
||||
// a room with anymore. This is used to cleanup stale device list entries, where we would
|
||||
// otherwise keep on trying to get device lists.
|
||||
type QueryLeftUsersRequest struct {
|
||||
StaleDeviceListUsers []string `json:"user_ids"`
|
||||
}
|
||||
|
||||
// QueryLeftUsersResponse is the response to QueryLeftUsersRequest.
|
||||
type QueryLeftUsersResponse struct {
|
||||
LeftUsers []string `json:"user_ids"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
|
|
@ -19,9 +23,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
||||
|
|
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
r.fsAPI = fsAPI
|
||||
r.KeyRing = keyRing
|
||||
|
||||
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
|
||||
if err != nil {
|
||||
logrus.Panic(err)
|
||||
}
|
||||
|
||||
r.Inputer = &input.Inputer{
|
||||
Cfg: &r.Base.Cfg.RoomServer,
|
||||
Base: r.Base,
|
||||
|
|
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
JetStream: r.JetStream,
|
||||
NATSClient: r.NATSClient,
|
||||
Durable: nats.Durable(r.Durable),
|
||||
ServerName: r.Cfg.Matrix.ServerName,
|
||||
ServerName: r.ServerName,
|
||||
SigningIdentity: identity,
|
||||
FSAPI: fsAPI,
|
||||
KeyRing: keyRing,
|
||||
ACLs: r.ServerACLs,
|
||||
|
|
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
Queryer: r.Queryer,
|
||||
}
|
||||
r.Peeker = &perform.Peeker{
|
||||
ServerName: r.Cfg.Matrix.ServerName,
|
||||
ServerName: r.ServerName,
|
||||
Cfg: r.Cfg,
|
||||
DB: r.DB,
|
||||
FSAPI: r.fsAPI,
|
||||
|
|
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
Inputer: r.Inputer,
|
||||
}
|
||||
r.Unpeeker = &perform.Unpeeker{
|
||||
ServerName: r.Cfg.Matrix.ServerName,
|
||||
ServerName: r.ServerName,
|
||||
Cfg: r.Cfg,
|
||||
DB: r.DB,
|
||||
FSAPI: r.fsAPI,
|
||||
|
|
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
|
||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||
r.Leaver.UserAPI = userAPI
|
||||
r.Inputer.UserAPI = userAPI
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||
|
|
|
|||
|
|
@ -23,6 +23,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
"github.com/Arceliar/phony"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -79,6 +81,7 @@ type Inputer struct {
|
|||
JetStream nats.JetStreamContext
|
||||
Durable nats.SubOpt
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
SigningIdentity *gomatrixserverlib.SigningIdentity
|
||||
FSAPI fedapi.RoomserverFederationAPI
|
||||
KeyRing gomatrixserverlib.JSONVerifier
|
||||
ACLs *acls.ServerACLs
|
||||
|
|
@ -87,6 +90,7 @@ type Inputer struct {
|
|||
workers sync.Map // room ID -> *worker
|
||||
|
||||
Queryer *query.Queryer
|
||||
UserAPI userapi.RoomserverUserAPI
|
||||
}
|
||||
|
||||
// If a room consumer is inactive for a while then we will allow NATS
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ package input
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
|
@ -31,6 +32,8 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
|
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
}
|
||||
|
||||
// If guest_access changed and is not can_join, kick all guest users.
|
||||
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
|
||||
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
|
||||
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
|
||||
}
|
||||
}
|
||||
|
||||
// Everything was OK — the latest events updater didn't error and
|
||||
// we've sent output events. Finally, generate a hook call.
|
||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||
|
|
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
|
|||
succeeded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
|
||||
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
|
||||
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
|
||||
latestReq := &api.QueryLatestEventsAndStateRequest{
|
||||
RoomID: event.RoomID(),
|
||||
}
|
||||
latestRes := &api.QueryLatestEventsAndStateResponse{}
|
||||
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prevEvents := latestRes.LatestEvents
|
||||
for _, memberEvent := range memberEvents {
|
||||
if memberEvent.StateKey() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
|
||||
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: senderDomain,
|
||||
}, accountRes); err != nil {
|
||||
return err
|
||||
}
|
||||
if accountRes.Account == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
|
||||
continue
|
||||
}
|
||||
|
||||
var memberContent gomatrixserverlib.MemberContent
|
||||
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
|
||||
return err
|
||||
}
|
||||
memberContent.Membership = gomatrixserverlib.Leave
|
||||
|
||||
stateKey := *memberEvent.StateKey()
|
||||
fledglingEvent := &gomatrixserverlib.EventBuilder{
|
||||
RoomID: event.RoomID(),
|
||||
Type: gomatrixserverlib.MRoomMember,
|
||||
StateKey: &stateKey,
|
||||
Sender: stateKey,
|
||||
PrevEvents: prevEvents,
|
||||
}
|
||||
|
||||
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inputEvents = append(inputEvents, api.InputRoomEvent{
|
||||
Kind: api.KindNew,
|
||||
Event: event,
|
||||
Origin: senderDomain,
|
||||
SendAsServer: string(senderDomain),
|
||||
})
|
||||
prevEvents = []gomatrixserverlib.EventReference{
|
||||
event.EventReference(),
|
||||
}
|
||||
}
|
||||
|
||||
inputReq := &api.InputRoomEventsRequest{
|
||||
InputRoomEvents: inputEvents,
|
||||
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
|
||||
}
|
||||
inputRes := &api.InputRoomEventsResponse{}
|
||||
return r.InputRoomEvents(ctx, inputReq, inputRes)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package perform
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
|
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
|
|||
}
|
||||
}
|
||||
|
||||
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||
if req.IsGuest {
|
||||
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
|
||||
guestAccess := "forbidden"
|
||||
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
|
||||
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||
}
|
||||
if guestAccessEvent != nil {
|
||||
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||
}
|
||||
|
||||
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||
// is present on the room and has the guest_access value can_join.
|
||||
if guestAccess != "can_join" {
|
||||
return "", "", &rsAPI.PerformError{
|
||||
Code: rsAPI.PerformErrorNotAllowed,
|
||||
Msg: "Guest access is forbidden",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we should do a forced federated join then do that.
|
||||
var joinedVia gomatrixserverlib.ServerName
|
||||
if forceFederatedJoin {
|
||||
|
|
|
|||
|
|
@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
|
||||
var err error
|
||||
res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ const (
|
|||
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
|
||||
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
|
||||
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
|
||||
RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
|
||||
)
|
||||
|
||||
type httpRoomserverInternalAPI struct {
|
||||
|
|
@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
|
|||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe
|
|||
RoomserverQueryMembershipAtEventPath,
|
||||
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
RoomserverQueryLeftMembersPath,
|
||||
httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,27 +2,163 @@ package roomserver_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
|
||||
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
|
||||
t.Helper()
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
|
||||
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Database: %v", err)
|
||||
}
|
||||
return base, db, close
|
||||
}
|
||||
|
||||
func Test_SharedUsers(t *testing.T) {
|
||||
func TestUsers(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer close()
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
// SetFederationAPI starts the room event input consumer
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
|
||||
t.Run("shared users", func(t *testing.T) {
|
||||
testSharedUsers(t, rsAPI)
|
||||
})
|
||||
|
||||
t.Run("kick users", func(t *testing.T) {
|
||||
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
|
||||
rsAPI.SetUserAPI(usrAPI)
|
||||
testKickUsers(t, rsAPI, usrAPI)
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||
|
||||
// Invite and join Bob
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "invite",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create the room
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Errorf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// Query the shared users for Alice, there should only be Bob.
|
||||
// This is used by the SyncAPI keychange consumer.
|
||||
res := &api.QuerySharedUsersResponse{}
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||
t.Errorf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
}
|
||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||
// This is used by the SyncAPI when getting device list changes.
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||
t.Errorf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
}
|
||||
}
|
||||
|
||||
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
|
||||
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
|
||||
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
|
||||
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
|
||||
|
||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
|
||||
|
||||
// Join with the guest user
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create the users in the userapi, so the RSAPI can query the account type later
|
||||
for _, u := range []*test.User{alice, bob} {
|
||||
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||
userRes := &userAPI.PerformAccountCreationResponse{}
|
||||
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
|
||||
AccountType: u.AccountType,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
Password: "someRandomPassword",
|
||||
}, userRes); err != nil {
|
||||
t.Errorf("failed to create account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the room in the database
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Errorf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// Get the membership events BEFORE revoking guest access
|
||||
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
||||
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
|
||||
t.Errorf("failed to query membership for room: %s", err)
|
||||
}
|
||||
|
||||
// revoke guest access
|
||||
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
|
||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
|
||||
t.Errorf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
|
||||
// to loop and wait for the events to be processed by the roomserver.
|
||||
for i := 0; i <= 20; i++ {
|
||||
// Get the membership events AFTER revoking guest access
|
||||
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
|
||||
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
|
||||
t.Errorf("failed to query membership for room: %s", err)
|
||||
}
|
||||
|
||||
// The membership events should NOT match, as Bob (guest user) should now be kicked from the room
|
||||
if !reflect.DeepEqual(membershipRes, membershipRes2) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
|
||||
t.Errorf("memberships didn't change in time")
|
||||
}
|
||||
|
||||
func Test_QueryLeftUsers(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||
|
|
@ -48,22 +184,42 @@ func Test_SharedUsers(t *testing.T) {
|
|||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// Query the shared users for Alice, there should only be Bob.
|
||||
// This is used by the SyncAPI keychange consumer.
|
||||
res := &api.QuerySharedUsersResponse{}
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||
t.Fatalf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
}
|
||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||
// This is used by the SyncAPI when getting device list changes.
|
||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||
t.Fatalf("unable to query known users: %v", err)
|
||||
}
|
||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||
// Query the left users, there should only be "@idontexist:test",
|
||||
// as Alice and Bob are still joined.
|
||||
res := &api.QueryLeftUsersResponse{}
|
||||
leftUserID := "@idontexist:test"
|
||||
getLeftUsersList := []string{alice.ID, bob.ID, leftUserID}
|
||||
|
||||
testCase := func(rsAPI api.RoomserverInternalAPI) {
|
||||
if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil {
|
||||
t.Fatalf("unable to query left users: %v", err)
|
||||
}
|
||||
wantCount := 1
|
||||
if count := len(res.LeftUsers); count > wantCount {
|
||||
t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count)
|
||||
}
|
||||
if res.LeftUsers[0] != leftUserID {
|
||||
t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0])
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("HTTP API", func(t *testing.T) {
|
||||
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||
roomserver.AddInternalRoutes(router, rsAPI, false)
|
||||
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||
defer cancel()
|
||||
httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create HTTP client")
|
||||
}
|
||||
testCase(httpAPI)
|
||||
})
|
||||
t.Run("Monolith", func(t *testing.T) {
|
||||
testCase(rsAPI)
|
||||
// also test tracing
|
||||
traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI}
|
||||
testCase(traceAPI)
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,5 +172,6 @@ type Database interface {
|
|||
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
||||
|
||||
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
|
||||
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
|
||||
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,12 +21,13 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const membershipSchema = `
|
||||
|
|
@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" +
|
|||
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
||||
|
||||
const selectJoinedUsersSQL = `
|
||||
SELECT DISTINCT target_nid
|
||||
FROM roomserver_membership m
|
||||
WHERE membership_nid > $1 AND target_nid = ANY($2)
|
||||
`
|
||||
|
||||
type membershipStatements struct {
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
|
|
@ -174,6 +181,7 @@ type membershipStatements struct {
|
|||
selectLocalServerInRoomStmt *sql.Stmt
|
||||
selectServerInRoomStmt *sql.Stmt
|
||||
deleteMembershipStmt *sql.Stmt
|
||||
selectJoinedUsersStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func CreateMembershipTable(db *sql.DB) error {
|
||||
|
|
@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
|
||||
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
|
||||
{&s.deleteMembershipStmt, deleteMembershipSQL},
|
||||
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsers(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserNIDs []types.EventStateKeyNID,
|
||||
) ([]types.EventStateKeyNID, error) {
|
||||
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
|
||||
var targetNID types.EventStateKeyNID
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&targetNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, targetNID)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *membershipStatements) InsertMembership(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
|
|
|
|||
|
|
@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// GetLeftUsers calculates users we (the server) don't share a room with anymore.
|
||||
func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
|
||||
// Get the userNID for all users with a stale device list
|
||||
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
|
||||
userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
|
||||
// Create a map from userNID -> userID
|
||||
for userID, nid := range stateKeyNIDMap {
|
||||
userNIDs = append(userNIDs, nid)
|
||||
userNIDtoUserID[nid] = userID
|
||||
}
|
||||
|
||||
// Get all users whose membership is still join, knock or invite.
|
||||
stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove joined users from the "user with stale devices" list, which contains left AND joined users
|
||||
for _, joinedUser := range stillJoinedUsersNIDs {
|
||||
delete(userNIDtoUserID, joinedUser)
|
||||
}
|
||||
|
||||
// The users still in our userNIDtoUserID map are the users we don't share a room with anymore,
|
||||
// and the return value we are looking for.
|
||||
leftUsers := make([]string, 0, len(userNIDtoUserID))
|
||||
for _, userID := range userNIDtoUserID {
|
||||
leftUsers = append(leftUsers, userID)
|
||||
}
|
||||
|
||||
return leftUsers, nil
|
||||
}
|
||||
|
||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
|
||||
|
|
|
|||
96
roomserver/storage/shared/storage_test.go
Normal file
96
roomserver/storage/shared/storage_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package shared_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
)
|
||||
|
||||
func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) {
|
||||
t.Helper()
|
||||
|
||||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
base, _, _ := testrig.Base(nil)
|
||||
dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
|
||||
|
||||
db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
|
||||
var membershipTable tables.Membership
|
||||
var stateKeyTable tables.EventStateKeys
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = postgres.CreateMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
membershipTable, err = postgres.PrepareMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateEventStateKeysTable(db)
|
||||
assert.NoError(t, err)
|
||||
err = sqlite3.CreateMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
membershipTable, err = sqlite3.PrepareMembershipTable(db)
|
||||
assert.NoError(t, err)
|
||||
stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return &shared.Database{
|
||||
DB: db,
|
||||
EventStateKeysTable: stateKeyTable,
|
||||
MembershipTable: membershipTable,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
}, func() {
|
||||
err := base.Close()
|
||||
assert.NoError(t, err)
|
||||
clearDB()
|
||||
err = db.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetLeftUsers(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := test.NewUser(t)
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRoomserverDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// Create dummy entries
|
||||
for _, user := range []*test.User{alice, bob, charlie} {
|
||||
nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID)
|
||||
assert.NoError(t, err)
|
||||
err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true)
|
||||
assert.NoError(t, err)
|
||||
// We must update the membership with a non-zero event NID or it will get filtered out in later queries
|
||||
membershipNID := tables.MembershipStateLeaveOrBan
|
||||
if user == alice {
|
||||
membershipNID = tables.MembershipStateJoin
|
||||
}
|
||||
_, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership
|
||||
expectedUserIDs := []string{bob.ID, charlie.ID}
|
||||
leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID})
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expectedUserIDs, leftUsers)
|
||||
})
|
||||
}
|
||||
|
|
@ -21,12 +21,13 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const membershipSchema = `
|
||||
|
|
@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" +
|
|||
const deleteMembershipSQL = "" +
|
||||
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const selectJoinedUsersSQL = `
|
||||
SELECT DISTINCT target_nid
|
||||
FROM roomserver_membership m
|
||||
WHERE membership_nid > $1 AND target_nid IN ($2)
|
||||
`
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
insertMembershipStmt *sql.Stmt
|
||||
|
|
@ -149,6 +156,7 @@ type membershipStatements struct {
|
|||
selectLocalServerInRoomStmt *sql.Stmt
|
||||
selectServerInRoomStmt *sql.Stmt
|
||||
deleteMembershipStmt *sql.Stmt
|
||||
// selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
|
||||
}
|
||||
|
||||
func CreateMembershipTable(db *sql.DB) error {
|
||||
|
|
@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership(
|
|||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsers(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserNIDs []types.EventStateKeyNID,
|
||||
) ([]types.EventStateKeyNID, error) {
|
||||
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
|
||||
|
||||
qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
|
||||
|
||||
stmt, err := s.db.Prepare(qry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
|
||||
|
||||
params := make([]any, len(targetUserNIDs)+1)
|
||||
params[0] = tables.MembershipStateLeaveOrBan
|
||||
for i := range targetUserNIDs {
|
||||
params[i+1] = targetUserNIDs[i]
|
||||
}
|
||||
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
|
||||
var targetNID types.EventStateKeyNID
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&targetNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, targetNID)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -144,6 +144,7 @@ type Membership interface {
|
|||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
|
||||
SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
|
||||
}
|
||||
|
||||
type Published interface {
|
||||
|
|
|
|||
|
|
@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
|
|||
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(knownUsers))
|
||||
|
||||
// get users we share a room with, given their userNID
|
||||
joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
|
||||
assert.NoError(t, err)
|
||||
// Only userNIDs[0] is actually joined, so we only expect this userNID
|
||||
assert.Equal(t, userNIDs[:1], joinedUsers)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
|
|||
return id, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no signing identity %q", serverName)
|
||||
return nil, fmt.Errorf("no signing identity for %q", serverName)
|
||||
}
|
||||
|
||||
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
||||
|
|
|
|||
|
|
@ -16,8 +16,10 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
|
|
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SigningIdentityFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
virtualHosts []*VirtualHost
|
||||
serverName gomatrixserverlib.ServerName
|
||||
want *gomatrixserverlib.SigningIdentity
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no virtual hosts defined",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no identity found",
|
||||
serverName: gomatrixserverlib.ServerName("doesnotexist"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "found identity",
|
||||
serverName: gomatrixserverlib.ServerName("main"),
|
||||
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
|
||||
},
|
||||
{
|
||||
name: "identity found on virtual hosts",
|
||||
serverName: gomatrixserverlib.ServerName("vh2"),
|
||||
virtualHosts: []*VirtualHost{
|
||||
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
|
||||
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
|
||||
},
|
||||
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Global{
|
||||
VirtualHosts: tt.virtualHosts,
|
||||
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||
ServerName: "main",
|
||||
},
|
||||
}
|
||||
got, err := c.SigningIdentityFor(tt.serverName)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error {
|
|||
// Normal NATS subscription, used by Request/Reply
|
||||
_, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) {
|
||||
userID := msg.Header.Get(jetstream.UserID)
|
||||
presence, err := s.db.GetPresence(context.Background(), userID)
|
||||
presences, err := s.db.GetPresences(context.Background(), []string{userID})
|
||||
m := &nats.Msg{
|
||||
Header: nats.Header{},
|
||||
}
|
||||
|
|
@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error {
|
|||
}
|
||||
return
|
||||
}
|
||||
if presence == nil {
|
||||
presence = &types.PresenceInternal{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
presence := &types.PresenceInternal{
|
||||
UserID: userID,
|
||||
}
|
||||
if len(presences) > 0 {
|
||||
presence = presences[0]
|
||||
}
|
||||
|
||||
deviceRes := api.QueryDevicesResponse{}
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ type DatabaseTransaction interface {
|
|||
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
|
||||
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
|
||||
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
|
||||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||
GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error)
|
||||
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
|
||||
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
|
||||
}
|
||||
|
|
@ -186,7 +186,7 @@ type Database interface {
|
|||
}
|
||||
|
||||
type Presence interface {
|
||||
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
|
||||
GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error)
|
||||
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,12 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const presenceSchema = `
|
||||
|
|
@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" +
|
|||
" RETURNING id"
|
||||
|
||||
const selectPresenceForUserSQL = "" +
|
||||
"SELECT presence, status_msg, last_active_ts" +
|
||||
"SELECT user_id, presence, status_msg, last_active_ts" +
|
||||
" FROM syncapi_presence" +
|
||||
" WHERE user_id = $1 LIMIT 1"
|
||||
" WHERE user_id = ANY($1)"
|
||||
|
||||
const selectMaxPresenceSQL = "" +
|
||||
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
|
||||
|
|
@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence(
|
|||
return
|
||||
}
|
||||
|
||||
// GetPresenceForUser returns the current presence of a user.
|
||||
func (p *presenceStatements) GetPresenceForUser(
|
||||
// GetPresenceForUsers returns the current presence for a list of users.
|
||||
// If the user doesn't have a presence status yet, it is omitted from the response.
|
||||
func (p *presenceStatements) GetPresenceForUsers(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID string,
|
||||
) (*types.PresenceInternal, error) {
|
||||
result := &types.PresenceInternal{
|
||||
UserID: userID,
|
||||
}
|
||||
userIDs []string,
|
||||
) ([]*types.PresenceInternal, error) {
|
||||
result := make([]*types.PresenceInternal, 0, len(userIDs))
|
||||
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
||||
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
presence := &types.PresenceInternal{}
|
||||
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
presence.ClientFields.Presence = presence.Presence.String()
|
||||
result = append(result, presence)
|
||||
}
|
||||
result.ClientFields.Presence = result.Presence.String()
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -57,31 +57,23 @@ type Database struct {
|
|||
}
|
||||
|
||||
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
|
||||
return d.NewDatabaseTransaction(ctx)
|
||||
|
||||
/*
|
||||
TODO: Repeatable read is probably the right thing to do here,
|
||||
but it seems to cause some problems with the invite tests, so
|
||||
need to investigate that further.
|
||||
|
||||
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
|
||||
// Set the isolation level so that we see a snapshot of the database.
|
||||
// In PostgreSQL repeatable read transactions will see a snapshot taken
|
||||
// at the first query, and since the transaction is read-only it can't
|
||||
// run into any serialisation errors.
|
||||
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DatabaseTransaction{
|
||||
Database: d,
|
||||
ctx: ctx,
|
||||
txn: txn,
|
||||
}, nil
|
||||
*/
|
||||
txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
|
||||
// Set the isolation level so that we see a snapshot of the database.
|
||||
// In PostgreSQL repeatable read transactions will see a snapshot taken
|
||||
// at the first query, and since the transaction is read-only it can't
|
||||
// run into any serialisation errors.
|
||||
// https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DatabaseTransaction{
|
||||
Database: d,
|
||||
ctx: ctx,
|
||||
txn: txn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
|
||||
|
|
@ -572,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t
|
|||
return pos, err
|
||||
}
|
||||
|
||||
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
|
||||
return d.Presence.GetPresenceForUser(ctx, nil, userID)
|
||||
func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
|
||||
return d.Presence.GetPresenceForUsers(ctx, nil, userIDs)
|
||||
}
|
||||
|
||||
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
|
||||
|
|
|
|||
|
|
@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex
|
|||
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
|
||||
return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
|
||||
func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
|
||||
return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const presenceSchema = `
|
||||
|
|
@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" +
|
|||
" RETURNING id"
|
||||
|
||||
const selectPresenceForUserSQL = "" +
|
||||
"SELECT presence, status_msg, last_active_ts" +
|
||||
"SELECT user_id, presence, status_msg, last_active_ts" +
|
||||
" FROM syncapi_presence" +
|
||||
" WHERE user_id = $1 LIMIT 1"
|
||||
" WHERE user_id IN ($1)"
|
||||
|
||||
const selectMaxPresenceSQL = "" +
|
||||
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
|
||||
|
|
@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence(
|
|||
return
|
||||
}
|
||||
|
||||
// GetPresenceForUser returns the current presence of a user.
|
||||
func (p *presenceStatements) GetPresenceForUser(
|
||||
// GetPresenceForUsers returns the current presence for a list of users.
|
||||
// If the user doesn't have a presence status yet, it is omitted from the response.
|
||||
func (p *presenceStatements) GetPresenceForUsers(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID string,
|
||||
) (*types.PresenceInternal, error) {
|
||||
result := &types.PresenceInternal{
|
||||
UserID: userID,
|
||||
userIDs []string,
|
||||
) ([]*types.PresenceInternal, error) {
|
||||
qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
|
||||
prepStmt, err := p.db.Prepare(qry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
|
||||
err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed")
|
||||
|
||||
params := make([]interface{}, len(userIDs))
|
||||
for i := range userIDs {
|
||||
params[i] = userIDs[i]
|
||||
}
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
|
||||
result := make([]*types.PresenceInternal, 0, len(userIDs))
|
||||
for rows.Next() {
|
||||
presence := &types.PresenceInternal{}
|
||||
if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
presence.ClientFields.Presence = presence.Presence.String()
|
||||
result = append(result, presence)
|
||||
}
|
||||
result.ClientFields.Presence = result.Presence.String()
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ type Ignores interface {
|
|||
|
||||
type Presence interface {
|
||||
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
|
||||
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
|
||||
GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error)
|
||||
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
|
||||
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
|
||||
}
|
||||
|
|
|
|||
136
syncapi/storage/tables/presence_table_test.go
Normal file
136
syncapi/storage/tables/presence_table_test.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open db: %s", err)
|
||||
}
|
||||
|
||||
var tab tables.Presence
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresPresenceTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
var stream sqlite3.StreamIDStatements
|
||||
if err = stream.Prepare(db); err != nil {
|
||||
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||
}
|
||||
tab, err = sqlite3.NewSqlitePresenceTable(db, &stream)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make new table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestPresence(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
ctx := context.Background()
|
||||
|
||||
statusMsg := "Hello World!"
|
||||
timestamp := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
|
||||
var txn *sql.Tx
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustPresenceTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
// Insert some presences
|
||||
pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wantPos := types.StreamPosition(1)
|
||||
if pos != wantPos {
|
||||
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
|
||||
}
|
||||
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wantPos = 2
|
||||
if pos != wantPos {
|
||||
t.Errorf("expected pos to be %d, got %d", wantPos, pos)
|
||||
}
|
||||
|
||||
// verify the expected max presence ID
|
||||
maxPos, err := tab.GetMaxPresenceID(ctx, txn)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if maxPos != wantPos {
|
||||
t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos)
|
||||
}
|
||||
|
||||
// This should increment the position
|
||||
pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wantPos = pos
|
||||
if wantPos <= maxPos {
|
||||
t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos)
|
||||
}
|
||||
|
||||
// This should return only Bobs status
|
||||
presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c := len(presences); c > 1 {
|
||||
t.Errorf("expected only one presence, got %d", c)
|
||||
}
|
||||
|
||||
// Validate the response
|
||||
wantPresence := &types.PresenceInternal{
|
||||
UserID: bob.ID,
|
||||
Presence: types.PresenceOnline,
|
||||
StreamPos: wantPos,
|
||||
LastActiveTS: timestamp,
|
||||
ClientFields: types.PresenceClientResponse{
|
||||
LastActiveAgo: 0,
|
||||
Presence: types.PresenceOnline.String(),
|
||||
StatusMsg: &statusMsg,
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(wantPresence, presences[bob.ID]) {
|
||||
t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence)
|
||||
}
|
||||
|
||||
// Try getting presences for existing and non-existing users
|
||||
getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"}
|
||||
presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(presencesForUsers) >= len(getUsers) {
|
||||
t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ package streams
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
|||
return from
|
||||
}
|
||||
|
||||
if len(presences) == 0 {
|
||||
getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("getNeededUsersFromRequest failed")
|
||||
return from
|
||||
}
|
||||
|
||||
// Got no presence between range and no presence to get from the database
|
||||
if len(getPresenceForUsers) == 0 && len(presences) == 0 {
|
||||
return to
|
||||
}
|
||||
|
||||
// add newly joined rooms user presences
|
||||
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
|
||||
if len(newlyJoined) > 0 {
|
||||
// TODO: Check if this is working better than before.
|
||||
if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
|
||||
req.Log.WithError(err).Error("unable to refresh notifier lists")
|
||||
return from
|
||||
}
|
||||
NewlyJoinedLoop:
|
||||
for _, roomID := range newlyJoined {
|
||||
roomUsers := p.notifier.JoinedUsers(roomID)
|
||||
for i := range roomUsers {
|
||||
// we already got a presence from this user
|
||||
if _, ok := presences[roomUsers[i]]; ok {
|
||||
continue
|
||||
}
|
||||
// Bear in mind that this might return nil, but at least populating
|
||||
// a nil means that there's a map entry so we won't repeat this call.
|
||||
presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("unable to query presence for user")
|
||||
_ = snapshot.Rollback()
|
||||
return from
|
||||
}
|
||||
if len(presences) > req.Filter.Presence.Limit {
|
||||
break NewlyJoinedLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("unable to query presence for user")
|
||||
_ = snapshot.Rollback()
|
||||
return from
|
||||
}
|
||||
for _, presence := range dbPresences {
|
||||
presences[presence.UserID] = presence
|
||||
}
|
||||
|
||||
lastPos := from
|
||||
|
|
@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync(
|
|||
return lastPos
|
||||
}
|
||||
|
||||
func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) {
|
||||
getPresenceForUsers := []string{}
|
||||
// Add presence for users which newly joined a room
|
||||
for userID := range req.MembershipChanges {
|
||||
if _, ok := presences[userID]; ok {
|
||||
continue
|
||||
}
|
||||
getPresenceForUsers = append(getPresenceForUsers, userID)
|
||||
}
|
||||
|
||||
// add newly joined rooms user presences
|
||||
newlyJoined := joinedRooms(req.Response, req.Device.UserID)
|
||||
if len(newlyJoined) == 0 {
|
||||
return getPresenceForUsers, nil
|
||||
}
|
||||
|
||||
// TODO: Check if this is working better than before.
|
||||
if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
|
||||
return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err)
|
||||
}
|
||||
for _, roomID := range newlyJoined {
|
||||
roomUsers := p.notifier.JoinedUsers(roomID)
|
||||
for i := range roomUsers {
|
||||
// we already got a presence from this user
|
||||
if _, ok := presences[roomUsers[i]]; ok {
|
||||
continue
|
||||
}
|
||||
getPresenceForUsers = append(getPresenceForUsers, roomUsers[i])
|
||||
}
|
||||
}
|
||||
return getPresenceForUsers, nil
|
||||
}
|
||||
|
||||
func joinedRooms(res *types.Response, userID string) []string {
|
||||
var roomIDs []string
|
||||
for roomID, join := range res.Rooms.Join {
|
||||
|
|
|
|||
|
|
@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
|||
}
|
||||
|
||||
// ensure we also send the current status_msg to federated servers and not nil
|
||||
dbPresence, err := db.GetPresence(context.Background(), userID)
|
||||
dbPresence, err := db.GetPresences(context.Background(), []string{userID})
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return
|
||||
}
|
||||
if dbPresence != nil {
|
||||
newPresence.ClientFields = dbPresence.ClientFields
|
||||
if len(dbPresence) > 0 && dbPresence[0] != nil {
|
||||
newPresence.ClientFields = dbPresence[0].ClientFields
|
||||
}
|
||||
newPresence.ClientFields.Presence = presenceID.String()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ
|
|||
return 0, nil
|
||||
}
|
||||
|
||||
func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
|
||||
return &types.PresenceInternal{}, nil
|
||||
func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) {
|
||||
return []*types.PresenceInternal{}, nil
|
||||
}
|
||||
|
||||
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
|
||||
|
|
|
|||
|
|
@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one
|
|||
Leaves are present in non-gapped incremental syncs
|
||||
|
||||
# Below test was passing for the wrong reason, failing correctly since #2858
|
||||
New federated private chats get full presence information (SYN-115)
|
||||
New federated private chats get full presence information (SYN-115)
|
||||
|
||||
# We don't have any state to calculate m.room.guest_access when accepting invites
|
||||
Guest users can accept invites to private rooms over federation
|
||||
|
|
@ -763,4 +763,7 @@ AS and main public room lists are separate
|
|||
local user has tags copied to the new room
|
||||
remote user has tags copied to the new room
|
||||
/upgrade moves remote aliases to the new room
|
||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||
Guest users denied access over federation if guest access prohibited
|
||||
Guest users are kicked from guest_access rooms on revocation of guest_access
|
||||
Guest users are kicked from guest_access rooms on revocation of guest_access over federation
|
||||
22
test/room.go
22
test/room.go
|
|
@ -38,11 +38,12 @@ var (
|
|||
)
|
||||
|
||||
type Room struct {
|
||||
ID string
|
||||
Version gomatrixserverlib.RoomVersion
|
||||
preset Preset
|
||||
visibility gomatrixserverlib.HistoryVisibility
|
||||
creator *User
|
||||
ID string
|
||||
Version gomatrixserverlib.RoomVersion
|
||||
preset Preset
|
||||
guestCanJoin bool
|
||||
visibility gomatrixserverlib.HistoryVisibility
|
||||
creator *User
|
||||
|
||||
authEvents gomatrixserverlib.AuthEvents
|
||||
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
||||
|
|
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
|||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||
if r.guestCanJoin {
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
|
||||
"guest_access": "can_join",
|
||||
}, WithStateKey(""))
|
||||
}
|
||||
}
|
||||
|
||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||
|
|
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
|||
r.Version = ver
|
||||
}
|
||||
}
|
||||
|
||||
func GuestsCanJoin(canJoin bool) roomModifier {
|
||||
return func(t *testing.T, r *Room) {
|
||||
r.guestCanJoin = canJoin
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat
|
|||
cfg.Global.JetStream.InMemory = true
|
||||
cfg.SyncAPI.Fulltext.InMemory = true
|
||||
cfg.FederationAPI.KeyPerspectives = nil
|
||||
base := base.NewBaseDendrite(cfg, "Tests")
|
||||
base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics)
|
||||
js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream)
|
||||
return base, js, jc
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ var (
|
|||
|
||||
type User struct {
|
||||
ID string
|
||||
accountType api.AccountType
|
||||
AccountType api.AccountType
|
||||
// key ID and private key of the server who has this user, if known.
|
||||
keyID gomatrixserverlib.KeyID
|
||||
privKey ed25519.PrivateKey
|
||||
|
|
@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
|
|||
|
||||
func WithAccountType(accountType api.AccountType) UserOpt {
|
||||
return func(u *User) {
|
||||
u.accountType = accountType
|
||||
u.AccountType = accountType
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
|
|||
|
||||
type RoomserverUserAPI interface {
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||
}
|
||||
|
||||
// api functions required by the media api
|
||||
|
|
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
|
|||
ServerName gomatrixserverlib.ServerName
|
||||
Medium string
|
||||
}
|
||||
|
||||
type QueryAccountByLocalpartRequest struct {
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryAccountByLocalpartResponse struct {
|
||||
Account *Account
|
||||
}
|
||||
|
|
|
|||
|
|
@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
|||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
|
||||
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func js(thing interface{}) string {
|
||||
b, err := json.Marshal(thing)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
|
||||
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
|
||||
return
|
||||
}
|
||||
|
||||
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
||||
// creating a 'device'.
|
||||
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ const (
|
|||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
|
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
|
|||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
|
||||
ctx context.Context,
|
||||
req *api.QueryAccountByLocalpartRequest,
|
||||
res *api.QueryAccountByLocalpartResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
|
||||
h.httpClient, ctx, req, res,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
|
|||
PerformSaveThreePIDAssociationPath,
|
||||
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
QueryAccountByLocalpartPath,
|
||||
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryAccountByLocalpart(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
|
||||
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
|
||||
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
|
||||
// Query existing account
|
||||
queryAccResp := &api.QueryAccountByLocalpartResponse{}
|
||||
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: userServername,
|
||||
}, queryAccResp); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
|
||||
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
|
||||
}
|
||||
|
||||
// Query non-existent account, this should result in an error
|
||||
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||
Localpart: "doesnotexist",
|
||||
ServerName: userServername,
|
||||
}, queryAccResp)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Monolith", func(t *testing.T) {
|
||||
testCases(t, intAPI)
|
||||
// also test tracing
|
||||
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
|
||||
})
|
||||
|
||||
t.Run("HTTP API", func(t *testing.T) {
|
||||
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||
userapi.AddInternalRoutes(router, intAPI, false)
|
||||
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||
defer cancel()
|
||||
|
||||
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create HTTP client: %s", err)
|
||||
}
|
||||
testCases(t, userHTTPApi)
|
||||
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
|||
119
userapi/util/notify_test.go
Normal file
119
userapi/util/notify_test.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package util_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
userUtil "github.com/matrix-org/dendrite/userapi/util"
|
||||
)
|
||||
|
||||
func TestNotifyUserCountsAsync(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a test room, just used to provide events
|
||||
room := test.NewRoom(t, alice)
|
||||
dummyEvent := room.Events()[len(room.Events())-1]
|
||||
|
||||
appID := util.RandomString(8)
|
||||
pushKey := util.RandomString(8)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
receivedRequest := make(chan bool, 1)
|
||||
// create a test server which responds to our /notify call
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var data pushgateway.NotifyRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
notification := data.Notification
|
||||
// Validate the request
|
||||
if notification.Counts == nil {
|
||||
t.Fatal("no unread notification counts in request")
|
||||
}
|
||||
if unread := notification.Counts.Unread; unread != 1 {
|
||||
t.Errorf("expected one unread notification, got %d", unread)
|
||||
}
|
||||
|
||||
if len(notification.Devices) == 0 {
|
||||
t.Fatal("expected devices in request")
|
||||
}
|
||||
|
||||
// We only created one push device, so access it directly
|
||||
device := notification.Devices[0]
|
||||
if device.AppID != appID {
|
||||
t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID)
|
||||
}
|
||||
if device.PushKey != pushKey {
|
||||
t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey)
|
||||
}
|
||||
|
||||
// Return empty result, otherwise the call is handled as failed
|
||||
if _, err := w.Write([]byte("{}")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
close(receivedRequest)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Create DB and Dendrite base
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
defer close()
|
||||
base, _, _ := testrig.Base(nil)
|
||||
defer base.Close()
|
||||
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "test", bcrypt.MinCost, 0, 0, "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Prepare pusher with our test server URL
|
||||
if err := db.UpsertPusher(ctx, api.Pusher{
|
||||
Kind: api.HTTPKind,
|
||||
AppID: appID,
|
||||
PushKey: pushKey,
|
||||
Data: map[string]interface{}{
|
||||
"url": srv.URL,
|
||||
},
|
||||
}, aliceLocalpart, serverName); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Insert a dummy event
|
||||
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
|
||||
Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll),
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Notify the user about a new notification
|
||||
if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
select {
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("timed out waiting for response")
|
||||
case <-receivedRequest:
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue