mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-21 13:03:09 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into neilalexander/purgeroom
This commit is contained in:
commit
2b7d1023ba
134
clientapi/admin_test.go
Normal file
134
clientapi/admin_test.go
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
package clientapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminResetPassword(t *testing.T) {
|
||||||
|
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
|
||||||
|
vhUser := &test.User{ID: "@vhuser:vh1"}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
// add a vhost
|
||||||
|
base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{
|
||||||
|
SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
// Needed for changing the password/login
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
|
||||||
|
AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil)
|
||||||
|
|
||||||
|
// Create the users in the userapi and login
|
||||||
|
accessTokens := map[*test.User]string{
|
||||||
|
aliceAdmin: "",
|
||||||
|
bob: "",
|
||||||
|
vhUser: "",
|
||||||
|
}
|
||||||
|
for u := range accessTokens {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &uapi.PerformAccountCreationResponse{}
|
||||||
|
password := util.RandomString(8)
|
||||||
|
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: password,
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"type": authtypes.LoginTypePassword,
|
||||||
|
"identifier": map[string]interface{}{
|
||||||
|
"type": "m.id.user",
|
||||||
|
"user": u.ID,
|
||||||
|
},
|
||||||
|
"password": password,
|
||||||
|
}))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
base.PublicClientAPIMux.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("failed to login: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
requestingUser *test.User
|
||||||
|
userID string
|
||||||
|
requestOpt test.HTTPRequestOpt
|
||||||
|
wantOK bool
|
||||||
|
withHeader bool
|
||||||
|
}{
|
||||||
|
{name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID},
|
||||||
|
{name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID},
|
||||||
|
{name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"password": util.RandomString(8),
|
||||||
|
})},
|
||||||
|
{name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s
|
||||||
|
{name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"password": "",
|
||||||
|
})},
|
||||||
|
{name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||||
|
{name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||||
|
{name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"password": util.RandomString(8),
|
||||||
|
})},
|
||||||
|
{name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID},
|
||||||
|
{name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
|
||||||
|
{name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)},
|
||||||
|
{name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"password": util.RandomString(6),
|
||||||
|
})},
|
||||||
|
{name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
|
||||||
|
"password": util.RandomString(513),
|
||||||
|
})},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID)
|
||||||
|
if tc.requestOpt != nil {
|
||||||
|
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.withHeader {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser])
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
base.DendriteAdminMux.ServeHTTP(rec, req)
|
||||||
|
t.Logf("%s", rec.Body.String())
|
||||||
|
if tc.wantOK && rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -15,6 +15,8 @@
|
||||||
package clientapi
|
package clientapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/api"
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
|
|
@ -26,7 +28,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
||||||
|
|
@ -57,10 +58,7 @@ func AddPublicRoutes(
|
||||||
}
|
}
|
||||||
|
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
base.PublicClientAPIMux,
|
base,
|
||||||
base.PublicWellKnownAPIMux,
|
|
||||||
base.SynapseAdminMux,
|
|
||||||
base.DendriteAdminMux,
|
|
||||||
cfg, rsAPI, asAPI,
|
cfg, rsAPI, asAPI,
|
||||||
userAPI, userDirectoryProvider, federation,
|
userAPI, userDirectoryProvider, federation,
|
||||||
syncProducer, transactionsCache, fsAPI, keyAPI,
|
syncProducer, transactionsCache, fsAPI, keyAPI,
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
|
@ -130,20 +131,40 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De
|
||||||
}
|
}
|
||||||
|
|
||||||
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
|
||||||
|
if req.Body == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.Unknown("Missing request body"),
|
||||||
|
}
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
serverName := cfg.Matrix.ServerName
|
var localpart string
|
||||||
localpart, ok := vars["localpart"]
|
userID := vars["userID"]
|
||||||
if !ok {
|
localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID)
|
||||||
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
JSON: jsonerror.InvalidArgumentValue(err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil {
|
accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
|
||||||
localpart, serverName = l, s
|
if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
}, accAvailableResp); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.InternalAPIError(req.Context(), err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if accAvailableResp.Available {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.Unknown("User does not exist"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
request := struct {
|
request := struct {
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
|
@ -160,6 +181,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
|
JSON: jsonerror.MissingArgument("Expecting non-empty password."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
|
||||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
|
||||||
joinReq := roomserverAPI.PerformJoinRequest{
|
joinReq := roomserverAPI.PerformJoinRequest{
|
||||||
RoomIDOrAlias: roomIDOrAlias,
|
RoomIDOrAlias: roomIDOrAlias,
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
|
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||||
Content: map[string]interface{}{},
|
Content: map[string]interface{}{},
|
||||||
}
|
}
|
||||||
joinRes := roomserverAPI.PerformJoinResponse{}
|
joinRes := roomserverAPI.PerformJoinResponse{}
|
||||||
|
|
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
|
||||||
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
||||||
done <- jsonerror.InternalAPIError(req.Context(), err)
|
done <- jsonerror.InternalAPIError(req.Context(), err)
|
||||||
} else if joinRes.Error != nil {
|
} else if joinRes.Error != nil {
|
||||||
done <- joinRes.Error.JSONResponse()
|
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
|
||||||
|
done <- util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
done <- joinRes.Error.JSONResponse()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
done <- util.JSONResponse{
|
done <- util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
||||||
158
clientapi/routing/joinroom_test.go
Normal file
158
clientapi/routing/joinroom_test.go
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/appservice"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
|
||||||
|
|
||||||
|
// Create the users in the userapi
|
||||||
|
for _, u := range []*test.User{alice, bob, charlie} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &uapi.PerformAccountCreationResponse{}
|
||||||
|
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
aliceDev := &uapi.Device{UserID: alice.ID}
|
||||||
|
bobDev := &uapi.Device{UserID: bob.ID}
|
||||||
|
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
|
||||||
|
|
||||||
|
// create a room with disabled guest access and invite Bob
|
||||||
|
resp := createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
RoomAliasName: "alias",
|
||||||
|
Invite: []string{bob.ID},
|
||||||
|
GuestCanJoin: false,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crResp, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a room with guest access enabled and invite Charlie
|
||||||
|
resp = createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
Invite: []string{charlie.ID},
|
||||||
|
GuestCanJoin: true,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dummy request
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
device *uapi.Device
|
||||||
|
roomID string
|
||||||
|
wantHTTP200 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User can join successfully by alias",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User can join successfully by roomID",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "join is forbidden if user is guest",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "!doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user from different server",
|
||||||
|
device: &uapi.Device{UserID: "@wrong:server"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user doesn't exist locally",
|
||||||
|
device: &uapi.Device{UserID: "@doesnotexist:test"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid room ID",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "invalidRoomID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "roomAlias does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "#doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room with guest_access event",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crRespWithGuestAccess.RoomID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
|
||||||
|
if tc.wantHTTP200 && !joinResp.Is2xx() {
|
||||||
|
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -81,7 +82,7 @@ func Password(
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||||
|
|
||||||
// Check the new password strength.
|
// Check the new password strength.
|
||||||
if resErr = validatePassword(r.NewPassword); resErr != nil {
|
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
|
@ -60,8 +61,6 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
|
||||||
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
|
||||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||||
sessionIDLength = 24
|
sessionIDLength = 24
|
||||||
)
|
)
|
||||||
|
|
@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePassword returns an error response if the password is invalid
|
|
||||||
func validatePassword(password string) *util.JSONResponse {
|
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
|
||||||
if len(password) > maxPasswordLength {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
|
|
||||||
}
|
|
||||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||||
func validateRecaptcha(
|
func validateRecaptcha(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
|
|
@ -636,7 +618,7 @@ func Register(
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resErr := validatePassword(r.Password); resErr != nil {
|
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
||||||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
if resErr := validatePassword(ssrr.Password); resErr != nil {
|
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
deviceID := "shared_secret_registration"
|
deviceID := "shared_secret_registration"
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
|
@ -49,7 +50,7 @@ import (
|
||||||
// applied:
|
// applied:
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func Setup(
|
func Setup(
|
||||||
publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
|
base *base.BaseDendrite,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
|
@ -63,7 +64,14 @@ func Setup(
|
||||||
extRoomsProvider api.ExtraPublicRoomsProvider,
|
extRoomsProvider api.ExtraPublicRoomsProvider,
|
||||||
mscCfg *config.MSCs, natsClient *nats.Conn,
|
mscCfg *config.MSCs, natsClient *nats.Conn,
|
||||||
) {
|
) {
|
||||||
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
publicAPIMux := base.PublicClientAPIMux
|
||||||
|
wkMux := base.PublicWellKnownAPIMux
|
||||||
|
synapseAdminRouter := base.SynapseAdminMux
|
||||||
|
dendriteAdminRouter := base.DendriteAdminMux
|
||||||
|
|
||||||
|
if base.EnableMetrics {
|
||||||
|
prometheus.MustRegister(amtRegUsers, sendEventDuration)
|
||||||
|
}
|
||||||
|
|
||||||
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
|
||||||
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
|
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
|
||||||
|
|
@ -637,7 +645,7 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
return AuthFallback(w, req, vars["authType"], cfg)
|
return AuthFallback(w, req, vars["authType"], cfg)
|
||||||
}),
|
}),
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL
|
||||||
all rooms which they are currently joined. A JSON body will be returned containing
|
all rooms which they are currently joined. A JSON body will be returned containing
|
||||||
the room IDs of all affected rooms.
|
the room IDs of all affected rooms.
|
||||||
|
|
||||||
## POST `/_dendrite/admin/resetPassword/{localpart}`
|
## POST `/_dendrite/admin/resetPassword/{userID}`
|
||||||
|
|
||||||
|
Reset the password of a local user.
|
||||||
|
|
||||||
Request body format:
|
Request body format:
|
||||||
|
|
||||||
|
|
@ -54,9 +56,6 @@ Request body format:
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Reset the password of a local user. The `localpart` is the username only, i.e. if
|
|
||||||
the full user ID is `@alice:domain.com` then the local part is `alice`.
|
|
||||||
|
|
||||||
## GET `/_dendrite/admin/fulltext/reindex`
|
## GET `/_dendrite/admin/fulltext/reindex`
|
||||||
|
|
||||||
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.
|
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.
|
||||||
|
|
|
||||||
|
|
@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
|
||||||
d.dbMutex.Lock()
|
|
||||||
defer d.dbMutex.Unlock()
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
if pdus, ok := d.associatedPDUs[serverName]; ok {
|
|
||||||
count = int64(len(pdus))
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
|
||||||
d.dbMutex.Lock()
|
|
||||||
defer d.dbMutex.Unlock()
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
if edus, ok := d.associatedEDUs[serverName]; ok {
|
|
||||||
count = int64(len(edus))
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
||||||
d.dbMutex.Lock()
|
d.dbMutex.Lock()
|
||||||
defer d.dbMutex.Unlock()
|
defer d.dbMutex.Unlock()
|
||||||
|
|
|
||||||
|
|
@ -45,9 +45,6 @@ type Database interface {
|
||||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
|
|
||||||
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
|
|
||||||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectInboundPeeksSQL = "" +
|
const selectInboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts"
|
||||||
|
|
||||||
const renewInboundPeekSQL = "" +
|
const renewInboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteInboundPeekSQL = "" +
|
const deleteInboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteInboundPeeksSQL = "" +
|
const deleteInboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||||
|
|
@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||||
}
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
return
|
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||||
}
|
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||||
return
|
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||||
|
|
|
||||||
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectOutboundPeeksSQL = "" +
|
const selectOutboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewOutboundPeekSQL = "" +
|
const renewOutboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteOutboundPeekSQL = "" +
|
const deleteOutboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteOutboundPeeksSQL = "" +
|
const deleteOutboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||||
|
|
@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||||
}
|
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||||
return
|
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||||
}
|
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||||
return
|
}.Prepare(db)
|
||||||
}
|
|
||||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||||
|
|
|
||||||
|
|
@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueueEDUCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueueServerNamesSQL = "" +
|
const selectQueueServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||||
|
|
||||||
|
|
@ -81,7 +77,6 @@ type queueEDUsStatements struct {
|
||||||
deleteQueueEDUStmt *sql.Stmt
|
deleteQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueueEDUCountStmt *sql.Stmt
|
|
||||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||||
selectExpiredEDUsStmt *sql.Stmt
|
selectExpiredEDUsStmt *sql.Stmt
|
||||||
deleteExpiredEDUsStmt *sql.Stmt
|
deleteExpiredEDUsStmt *sql.Stmt
|
||||||
|
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
||||||
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
|
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
|
||||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
|
||||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||||
|
|
@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
|
|
||||||
|
|
@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueuePDUsCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueuePDUServerNamesSQL = "" +
|
const selectQueuePDUServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||||
|
|
||||||
|
|
@ -71,7 +67,6 @@ type queuePDUsStatements struct {
|
||||||
deleteQueuePDUsStmt *sql.Stmt
|
deleteQueuePDUsStmt *sql.Stmt
|
||||||
selectQueuePDUsStmt *sql.Stmt
|
selectQueuePDUsStmt *sql.Stmt
|
||||||
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueuePDUsCountStmt *sql.Stmt
|
|
||||||
selectQueuePDUServerNamesStmt *sql.Stmt
|
selectQueuePDUServerNamesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
||||||
|
|
@ -162,15 +162,6 @@ func (d *Database) CleanEDUs(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingEDUCount returns the number of EDUs waiting to be
|
|
||||||
// sent for a given servername.
|
|
||||||
func (d *Database) GetPendingEDUCount(
|
|
||||||
ctx context.Context,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingServerNames returns the server names that have EDUs
|
// GetPendingServerNames returns the server names that have EDUs
|
||||||
// waiting to be sent.
|
// waiting to be sent.
|
||||||
func (d *Database) GetPendingEDUServerNames(
|
func (d *Database) GetPendingEDUServerNames(
|
||||||
|
|
|
||||||
|
|
@ -141,15 +141,6 @@ func (d *Database) CleanPDUs(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingPDUCount returns the number of PDUs waiting to be
|
|
||||||
// sent for a given servername.
|
|
||||||
func (d *Database) GetPendingPDUCount(
|
|
||||||
ctx context.Context,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingServerNames returns the server names that have PDUs
|
// GetPendingServerNames returns the server names that have PDUs
|
||||||
// waiting to be sent.
|
// waiting to be sent.
|
||||||
func (d *Database) GetPendingPDUServerNames(
|
func (d *Database) GetPendingPDUServerNames(
|
||||||
|
|
|
||||||
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectInboundPeeksSQL = "" +
|
const selectInboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewInboundPeekSQL = "" +
|
const renewInboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteInboundPeekSQL = "" +
|
const deleteInboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteInboundPeeksSQL = "" +
|
const deleteInboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||||
|
|
@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||||
}
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
return
|
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||||
}
|
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||||
return
|
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||||
|
|
|
||||||
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectOutboundPeeksSQL = "" +
|
const selectOutboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewOutboundPeekSQL = "" +
|
const renewOutboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteOutboundPeekSQL = "" +
|
const deleteOutboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteOutboundPeeksSQL = "" +
|
const deleteOutboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||||
|
|
@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||||
}
|
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||||
return
|
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||||
}
|
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||||
return
|
}.Prepare(db)
|
||||||
}
|
|
||||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||||
|
|
|
||||||
|
|
@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueueEDUCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueueServerNamesSQL = "" +
|
const selectQueueServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||||
|
|
||||||
|
|
@ -82,7 +78,6 @@ type queueEDUsStatements struct {
|
||||||
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
|
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueueEDUCountStmt *sql.Stmt
|
|
||||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||||
selectExpiredEDUsStmt *sql.Stmt
|
selectExpiredEDUsStmt *sql.Stmt
|
||||||
deleteExpiredEDUsStmt *sql.Stmt
|
deleteExpiredEDUsStmt *sql.Stmt
|
||||||
|
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
||||||
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
|
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
|
||||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
|
||||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||||
|
|
@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
|
|
||||||
|
|
@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueuePDUsCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueuePDUsServerNamesSQL = "" +
|
const selectQueuePDUsServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||||
|
|
||||||
|
|
@ -79,7 +75,6 @@ type queuePDUsStatements struct {
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsStmt *sql.Stmt
|
selectQueuePDUsStmt *sql.Stmt
|
||||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueuePDUsCountStmt *sql.Stmt
|
|
||||||
selectQueueServerNamesStmt *sql.Stmt
|
selectQueueServerNamesStmt *sql.Stmt
|
||||||
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
|
@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
|
|
@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
assert.Equal(t, 2, len(data))
|
assert.Equal(t, 2, len(data))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOutboundPeeking(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
|
||||||
|
// Add outbound peek
|
||||||
|
if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if outboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
// Renew the peek
|
||||||
|
if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
peekIDs := []string{peekID}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||||
|
}
|
||||||
|
gotPeekIDs := make([]string, 0, len(outboundPeeks))
|
||||||
|
for _, p := range outboundPeeks {
|
||||||
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInboundPeeking(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
|
||||||
|
// Add inbound peek
|
||||||
|
if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if inboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
// Renew the peek
|
||||||
|
if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
peekIDs := []string{peekID}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||||
|
}
|
||||||
|
gotPeekIDs := make([]string, 0, len(inboundPeeks))
|
||||||
|
for _, p := range inboundPeeks {
|
||||||
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
149
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
149
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open database: %s", err)
|
||||||
|
}
|
||||||
|
var tab tables.FederationInboundPeeks
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresInboundPeeksTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteInboundPeeksTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInboundPeeksTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, closeDB := mustCreateInboundpeeksTable(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
// Insert a peek
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if inboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew the peek
|
||||||
|
if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the peek
|
||||||
|
if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be no peek anymore
|
||||||
|
peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if peek != nil {
|
||||||
|
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
var peekIDs []string
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||||
|
}
|
||||||
|
gotPeekIDs := make([]string, 0, len(inboundPeeks))
|
||||||
|
for _, p := range inboundPeeks {
|
||||||
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
|
|
||||||
|
// And delete them again
|
||||||
|
if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// they should be gone now
|
||||||
|
inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) > 0 {
|
||||||
|
t.Fatal("got inbound peeks which should be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -28,7 +28,6 @@ type FederationQueuePDUs interface {
|
||||||
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||||
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||||
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||||
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||||
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||||
}
|
}
|
||||||
|
|
@ -38,7 +37,6 @@ type FederationQueueEDUs interface {
|
||||||
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||||
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||||
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||||
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||||
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
|
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
|
||||||
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
|
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
|
||||||
|
|
|
||||||
148
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
148
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
|
|
@ -0,0 +1,148 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open database: %s", err)
|
||||||
|
}
|
||||||
|
var tab tables.FederationOutboundPeeks
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresOutboundPeeksTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutboundPeeksTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, closeDB := mustCreateOutboundpeeksTable(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
// Insert a peek
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if outboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew the peek
|
||||||
|
if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the peek
|
||||||
|
if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be no peek anymore
|
||||||
|
peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if peek != nil {
|
||||||
|
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
var peekIDs []string
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||||
|
}
|
||||||
|
gotPeekIDs := make([]string, 0, len(outboundPeeks))
|
||||||
|
for _, p := range outboundPeeks {
|
||||||
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
|
|
||||||
|
// And delete them again
|
||||||
|
if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// they should be gone now
|
||||||
|
outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) > 0 {
|
||||||
|
t.Fatal("got outbound peeks which should be deleted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
||||||
|
|
||||||
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
||||||
// This is used to serve HTML alongside JSON error messages
|
// This is used to serve HTML alongside JSON error messages
|
||||||
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
||||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||||
span := opentracing.StartSpan(metricsName)
|
span := opentracing.StartSpan(metricsName)
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !enableMetrics {
|
||||||
|
return http.HandlerFunc(withSpan)
|
||||||
|
}
|
||||||
|
|
||||||
return promhttp.InstrumentHandlerCounter(
|
return promhttp.InstrumentHandlerCounter(
|
||||||
promauto.NewCounterVec(
|
promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
|
|
|
||||||
44
internal/validate.go
Normal file
44
internal/validate.go
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||||
|
|
||||||
|
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
|
||||||
|
// ValidatePassword returns an error response if the password is invalid
|
||||||
|
func ValidatePassword(password string) *util.JSONResponse {
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
if len(password) > maxPasswordLength {
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
|
||||||
|
}
|
||||||
|
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -78,6 +78,7 @@ const (
|
||||||
type PerformJoinRequest struct {
|
type PerformJoinRequest struct {
|
||||||
RoomIDOrAlias string `json:"room_id_or_alias"`
|
RoomIDOrAlias string `json:"room_id_or_alias"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
IsGuest bool `json:"is_guest"`
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||||
Unsigned map[string]interface{} `json:"unsigned"`
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
|
@ -19,9 +23,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
||||||
|
|
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
r.fsAPI = fsAPI
|
r.fsAPI = fsAPI
|
||||||
r.KeyRing = keyRing
|
r.KeyRing = keyRing
|
||||||
|
|
||||||
|
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
r.Inputer = &input.Inputer{
|
r.Inputer = &input.Inputer{
|
||||||
Cfg: &r.Base.Cfg.RoomServer,
|
Cfg: &r.Base.Cfg.RoomServer,
|
||||||
Base: r.Base,
|
Base: r.Base,
|
||||||
|
|
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
JetStream: r.JetStream,
|
JetStream: r.JetStream,
|
||||||
NATSClient: r.NATSClient,
|
NATSClient: r.NATSClient,
|
||||||
Durable: nats.Durable(r.Durable),
|
Durable: nats.Durable(r.Durable),
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
|
SigningIdentity: identity,
|
||||||
FSAPI: fsAPI,
|
FSAPI: fsAPI,
|
||||||
KeyRing: keyRing,
|
KeyRing: keyRing,
|
||||||
ACLs: r.ServerACLs,
|
ACLs: r.ServerACLs,
|
||||||
|
|
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Queryer: r.Queryer,
|
Queryer: r.Queryer,
|
||||||
}
|
}
|
||||||
r.Peeker = &perform.Peeker{
|
r.Peeker = &perform.Peeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
|
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Inputer: r.Inputer,
|
Inputer: r.Inputer,
|
||||||
}
|
}
|
||||||
r.Unpeeker = &perform.Unpeeker{
|
r.Unpeeker = &perform.Unpeeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
|
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||||
r.Leaver.UserAPI = userAPI
|
r.Leaver.UserAPI = userAPI
|
||||||
|
r.Inputer.UserAPI = userAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/Arceliar/phony"
|
"github.com/Arceliar/phony"
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -79,6 +81,7 @@ type Inputer struct {
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
Durable nats.SubOpt
|
Durable nats.SubOpt
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
SigningIdentity *gomatrixserverlib.SigningIdentity
|
||||||
FSAPI fedapi.RoomserverFederationAPI
|
FSAPI fedapi.RoomserverFederationAPI
|
||||||
KeyRing gomatrixserverlib.JSONVerifier
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
ACLs *acls.ServerACLs
|
ACLs *acls.ServerACLs
|
||||||
|
|
@ -87,6 +90,7 @@ type Inputer struct {
|
||||||
workers sync.Map // room ID -> *worker
|
workers sync.Map // room ID -> *worker
|
||||||
|
|
||||||
Queryer *query.Queryer
|
Queryer *query.Queryer
|
||||||
|
UserAPI userapi.RoomserverUserAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a room consumer is inactive for a while then we will allow NATS
|
// If a room consumer is inactive for a while then we will allow NATS
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ package input
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -31,6 +32,8 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
|
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If guest_access changed and is not can_join, kick all guest users.
|
||||||
|
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
|
||||||
|
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Everything was OK — the latest events updater didn't error and
|
// Everything was OK — the latest events updater didn't error and
|
||||||
// we've sent output events. Finally, generate a hook call.
|
// we've sent output events. Finally, generate a hook call.
|
||||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||||
|
|
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
|
||||||
succeeded = true
|
succeeded = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
|
||||||
|
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
|
||||||
|
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
|
||||||
|
latestReq := &api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
}
|
||||||
|
latestRes := &api.QueryLatestEventsAndStateResponse{}
|
||||||
|
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prevEvents := latestRes.LatestEvents
|
||||||
|
for _, memberEvent := range memberEvents {
|
||||||
|
if memberEvent.StateKey() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: senderDomain,
|
||||||
|
}, accountRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if accountRes.Account == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var memberContent gomatrixserverlib.MemberContent
|
||||||
|
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
memberContent.Membership = gomatrixserverlib.Leave
|
||||||
|
|
||||||
|
stateKey := *memberEvent.StateKey()
|
||||||
|
fledglingEvent := &gomatrixserverlib.EventBuilder{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
Type: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: &stateKey,
|
||||||
|
Sender: stateKey,
|
||||||
|
PrevEvents: prevEvents,
|
||||||
|
}
|
||||||
|
|
||||||
|
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents = append(inputEvents, api.InputRoomEvent{
|
||||||
|
Kind: api.KindNew,
|
||||||
|
Event: event,
|
||||||
|
Origin: senderDomain,
|
||||||
|
SendAsServer: string(senderDomain),
|
||||||
|
})
|
||||||
|
prevEvents = []gomatrixserverlib.EventReference{
|
||||||
|
event.EventReference(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputReq := &api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: inputEvents,
|
||||||
|
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
|
||||||
|
}
|
||||||
|
inputRes := &api.InputRoomEventsResponse{}
|
||||||
|
return r.InputRoomEvents(ctx, inputReq, inputRes)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package perform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||||
|
if req.IsGuest {
|
||||||
|
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
|
||||||
|
guestAccess := "forbidden"
|
||||||
|
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
|
||||||
|
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||||
|
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||||
|
}
|
||||||
|
if guestAccessEvent != nil {
|
||||||
|
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||||
|
// is present on the room and has the guest_access value can_join.
|
||||||
|
if guestAccess != "can_join" {
|
||||||
|
return "", "", &rsAPI.PerformError{
|
||||||
|
Code: rsAPI.PerformErrorNotAllowed,
|
||||||
|
Msg: "Guest access is forbidden",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we should do a forced federated join then do that.
|
// If we should do a forced federated join then do that.
|
||||||
var joinedVia gomatrixserverlib.ServerName
|
var joinedVia gomatrixserverlib.ServerName
|
||||||
if forceFederatedJoin {
|
if forceFederatedJoin {
|
||||||
|
|
|
||||||
|
|
@ -3,23 +3,27 @@ package roomserver_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi"
|
"github.com/matrix-org/dendrite/federationapi"
|
||||||
"github.com/matrix-org/dendrite/keyserver"
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/syncapi"
|
"github.com/matrix-org/dendrite/syncapi"
|
||||||
"github.com/matrix-org/dendrite/userapi"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver"
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
)
|
)
|
||||||
|
|
@ -34,7 +38,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
|
||||||
return base, db, close
|
return base, db, close
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_SharedUsers(t *testing.T) {
|
func TestUsers(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer close()
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
// SetFederationAPI starts the room event input consumer
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
|
||||||
|
t.Run("shared users", func(t *testing.T) {
|
||||||
|
testSharedUsers(t, rsAPI)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("kick users", func(t *testing.T) {
|
||||||
|
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
|
||||||
|
rsAPI.SetUserAPI(usrAPI)
|
||||||
|
testKickUsers(t, rsAPI, usrAPI)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
bob := test.NewUser(t)
|
bob := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||||
|
|
@ -48,36 +73,93 @@ func Test_SharedUsers(t *testing.T) {
|
||||||
}, test.WithStateKey(bob.ID))
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
||||||
base, _, close := mustCreateDatabase(t, dbType)
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
// Create the room
|
||||||
// SetFederationAPI starts the room event input consumer
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
rsAPI.SetFederationAPI(nil, nil)
|
t.Errorf("failed to send events: %v", err)
|
||||||
// Create the room
|
}
|
||||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
|
||||||
t.Fatalf("failed to send events: %v", err)
|
// Query the shared users for Alice, there should only be Bob.
|
||||||
|
// This is used by the SyncAPI keychange consumer.
|
||||||
|
res := &api.QuerySharedUsersResponse{}
|
||||||
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||||
|
t.Errorf("unable to query known users: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
|
}
|
||||||
|
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||||
|
// This is used by the SyncAPI when getting device list changes.
|
||||||
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||||
|
t.Errorf("unable to query known users: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
|
||||||
|
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
|
||||||
|
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
|
||||||
|
|
||||||
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
|
||||||
|
|
||||||
|
// Join with the guest user
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create the users in the userapi, so the RSAPI can query the account type later
|
||||||
|
for _, u := range []*test.User{alice, bob} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &userAPI.PerformAccountCreationResponse{}
|
||||||
|
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the room in the database
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the membership events BEFORE revoking guest access
|
||||||
|
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// revoke guest access
|
||||||
|
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
|
||||||
|
// to loop and wait for the events to be processed by the roomserver.
|
||||||
|
for i := 0; i <= 20; i++ {
|
||||||
|
// Get the membership events AFTER revoking guest access
|
||||||
|
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the shared users for Alice, there should only be Bob.
|
// The membership events should NOT match, as Bob (guest user) should now be kicked from the room
|
||||||
// This is used by the SyncAPI keychange consumer.
|
if !reflect.DeepEqual(membershipRes, membershipRes2) {
|
||||||
res := &api.QuerySharedUsersResponse{}
|
return
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
|
||||||
}
|
}
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
time.Sleep(time.Millisecond * 10)
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
}
|
||||||
}
|
|
||||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
t.Errorf("memberships didn't change in time")
|
||||||
// This is used by the SyncAPI when getting device list changes.
|
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
|
||||||
}
|
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_QueryLeftUsers(t *testing.T) {
|
func Test_QueryLeftUsers(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("no signing identity %q", serverName)
|
return nil, fmt.Errorf("no signing identity for %q", serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,10 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_SigningIdentityFor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
virtualHosts []*VirtualHost
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
want *gomatrixserverlib.SigningIdentity
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no virtual hosts defined",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no identity found",
|
||||||
|
serverName: gomatrixserverlib.ServerName("doesnotexist"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "found identity",
|
||||||
|
serverName: gomatrixserverlib.ServerName("main"),
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identity found on virtual hosts",
|
||||||
|
serverName: gomatrixserverlib.ServerName("vh2"),
|
||||||
|
virtualHosts: []*VirtualHost{
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
|
||||||
|
},
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Global{
|
||||||
|
VirtualHosts: tt.virtualHosts,
|
||||||
|
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||||
|
ServerName: "main",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got, err := c.SigningIdentityFor(tt.serverName)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,3 +49,6 @@ Leaves are present in non-gapped incremental syncs
|
||||||
|
|
||||||
# Below test was passing for the wrong reason, failing correctly since #2858
|
# Below test was passing for the wrong reason, failing correctly since #2858
|
||||||
New federated private chats get full presence information (SYN-115)
|
New federated private chats get full presence information (SYN-115)
|
||||||
|
|
||||||
|
# We don't have any state to calculate m.room.guest_access when accepting invites
|
||||||
|
Guest users can accept invites to private rooms over federation
|
||||||
|
|
@ -764,3 +764,6 @@ local user has tags copied to the new room
|
||||||
remote user has tags copied to the new room
|
remote user has tags copied to the new room
|
||||||
/upgrade moves remote aliases to the new room
|
/upgrade moves remote aliases to the new room
|
||||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||||
|
Guest users denied access over federation if guest access prohibited
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access over federation
|
||||||
22
test/room.go
22
test/room.go
|
|
@ -38,11 +38,12 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Room struct {
|
type Room struct {
|
||||||
ID string
|
ID string
|
||||||
Version gomatrixserverlib.RoomVersion
|
Version gomatrixserverlib.RoomVersion
|
||||||
preset Preset
|
preset Preset
|
||||||
visibility gomatrixserverlib.HistoryVisibility
|
guestCanJoin bool
|
||||||
creator *User
|
visibility gomatrixserverlib.HistoryVisibility
|
||||||
|
creator *User
|
||||||
|
|
||||||
authEvents gomatrixserverlib.AuthEvents
|
authEvents gomatrixserverlib.AuthEvents
|
||||||
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
||||||
|
|
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||||
|
if r.guestCanJoin {
|
||||||
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
|
||||||
|
"guest_access": "can_join",
|
||||||
|
}, WithStateKey(""))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||||
|
|
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
||||||
r.Version = ver
|
r.Version = ver
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GuestsCanJoin(canJoin bool) roomModifier {
|
||||||
|
return func(t *testing.T, r *Room) {
|
||||||
|
r.guestCanJoin = canJoin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ var (
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string
|
ID string
|
||||||
accountType api.AccountType
|
AccountType api.AccountType
|
||||||
// key ID and private key of the server who has this user, if known.
|
// key ID and private key of the server who has this user, if known.
|
||||||
keyID gomatrixserverlib.KeyID
|
keyID gomatrixserverlib.KeyID
|
||||||
privKey ed25519.PrivateKey
|
privKey ed25519.PrivateKey
|
||||||
|
|
@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
|
||||||
|
|
||||||
func WithAccountType(accountType api.AccountType) UserOpt {
|
func WithAccountType(accountType api.AccountType) UserOpt {
|
||||||
return func(u *User) {
|
return func(u *User) {
|
||||||
u.accountType = accountType
|
u.AccountType = accountType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
|
||||||
|
|
||||||
type RoomserverUserAPI interface {
|
type RoomserverUserAPI interface {
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// api functions required by the media api
|
// api functions required by the media api
|
||||||
|
|
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
Medium string
|
Medium string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartRequest struct {
|
||||||
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartResponse struct {
|
||||||
|
Account *Account
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
|
||||||
|
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
|
||||||
|
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
||||||
// creating a 'device'.
|
// creating a 'device'.
|
||||||
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ const (
|
||||||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||||
|
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
|
||||||
h.httpClient, ctx, request, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryAccountByLocalpartRequest,
|
||||||
|
res *api.QueryAccountByLocalpartResponse,
|
||||||
|
) error {
|
||||||
|
return httputil.CallInternalRPCAPI(
|
||||||
|
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
|
||||||
|
h.httpClient, ctx, req, res,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
|
||||||
PerformSaveThreePIDAssociationPath,
|
PerformSaveThreePIDAssociationPath,
|
||||||
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
QueryAccountByLocalpartPath,
|
||||||
|
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryAccountByLocalpart(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
|
||||||
|
// Query existing account
|
||||||
|
queryAccResp := &api.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
|
||||||
|
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query non-existent account, this should result in an error
|
||||||
|
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: "doesnotexist",
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Monolith", func(t *testing.T) {
|
||||||
|
testCases(t, intAPI)
|
||||||
|
// also test tracing
|
||||||
|
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HTTP API", func(t *testing.T) {
|
||||||
|
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||||
|
userapi.AddInternalRoutes(router, intAPI, false)
|
||||||
|
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create HTTP client: %s", err)
|
||||||
|
}
|
||||||
|
testCases(t, userHTTPApi)
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue