Merge branch 'main' of github.com:matrix-org/dendrite into neilalexander/purgeroom

This commit is contained in:
Till Faelligen 2022-12-22 13:13:42 +01:00
commit 2b7d1023ba
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
45 changed files with 1349 additions and 305 deletions

134
clientapi/admin_test.go Normal file
View file

@ -0,0 +1,134 @@
package clientapi
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestAdminResetPassword(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
vhUser := &test.User{ID: "@vhuser:vh1"}
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
// add a vhost
base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{
SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"},
})
rsAPI := roomserver.NewInternalAPI(base)
// Needed for changing the password/login
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
aliceAdmin: "",
bob: "",
vhUser: "",
}
for u := range accessTokens {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
password := util.RandomString(8)
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: password,
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
"type": authtypes.LoginTypePassword,
"identifier": map[string]interface{}{
"type": "m.id.user",
"user": u.ID,
},
"password": password,
}))
rec := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("failed to login: %s", rec.Body.String())
}
accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String()
}
testCases := []struct {
name string
requestingUser *test.User
userID string
requestOpt test.HTTPRequestOpt
wantOK bool
withHeader bool
}{
{name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID},
{name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID},
{name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(8),
})},
{name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s
{name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": "",
})},
{name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(8),
})},
{name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID},
{name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})},
{name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)},
{name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(6),
})},
{name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{
"password": util.RandomString(513),
})},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID)
if tc.requestOpt != nil {
req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt)
}
if tc.withHeader {
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser])
}
rec := httptest.NewRecorder()
base.DendriteAdminMux.ServeHTTP(rec, req)
t.Logf("%s", rec.Body.String())
if tc.wantOK && rec.Code != http.StatusOK {
t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
}
})
}
})
}

View file

@ -15,6 +15,8 @@
package clientapi package clientapi
import ( import (
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
@ -26,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
@ -57,10 +58,7 @@ func AddPublicRoutes(
} }
routing.Setup( routing.Setup(
base.PublicClientAPIMux, base,
base.PublicWellKnownAPIMux,
base.SynapseAdminMux,
base.DendriteAdminMux,
cfg, rsAPI, asAPI, cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation, userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, syncProducer, transactionsCache, fsAPI, keyAPI,

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -130,20 +131,40 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De
} }
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
if req.Body == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Missing request body"),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName var localpart string
localpart, ok := vars["localpart"] userID := vars["userID"]
if !ok { localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
localpart, serverName = l, s if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
Localpart: localpart,
ServerName: serverName,
}, accAvailableResp); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.InternalAPIError(req.Context(), err),
}
}
if accAvailableResp.Available {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Unknown("User does not exist"),
}
} }
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
@ -160,6 +181,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting non-empty password."), JSON: jsonerror.MissingArgument("Expecting non-empty password."),
} }
} }
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
return *resErr
}
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName, ServerName: serverName,

View file

@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
joinReq := roomserverAPI.PerformJoinRequest{ joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias, RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID, UserID: device.UserID,
IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
} }
joinRes := roomserverAPI.PerformJoinResponse{} joinRes := roomserverAPI.PerformJoinResponse{}
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err) done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil { } else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse() if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
done <- util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else { } else {
done <- util.JSONResponse{ done <- util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -0,0 +1,158 @@
package routing
import (
"bytes"
"context"
"net/http"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestJoinRoomByIDOrAlias(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
// Create the users in the userapi
for _, u := range []*test.User{alice, bob, charlie} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
aliceDev := &uapi.Device{UserID: alice.ID}
bobDev := &uapi.Device{UserID: bob.ID}
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
// create a room with disabled guest access and invite Bob
resp := createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
RoomAliasName: "alias",
Invite: []string{bob.ID},
GuestCanJoin: false,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crResp, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// create a room with guest access enabled and invite Charlie
resp = createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
Invite: []string{charlie.ID},
GuestCanJoin: true,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// Dummy request
body := &bytes.Buffer{}
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
if err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
device *uapi.Device
roomID string
wantHTTP200 bool
}{
{
name: "User can join successfully by alias",
device: bobDev,
roomID: crResp.RoomAlias,
wantHTTP200: true,
},
{
name: "User can join successfully by roomID",
device: bobDev,
roomID: crResp.RoomID,
wantHTTP200: true,
},
{
name: "join is forbidden if user is guest",
device: charlieDev,
roomID: crResp.RoomID,
},
{
name: "room does not exist",
device: aliceDev,
roomID: "!doesnotexist:test",
},
{
name: "user from different server",
device: &uapi.Device{UserID: "@wrong:server"},
roomID: crResp.RoomAlias,
},
{
name: "user doesn't exist locally",
device: &uapi.Device{UserID: "@doesnotexist:test"},
roomID: crResp.RoomAlias,
},
{
name: "invalid room ID",
device: aliceDev,
roomID: "invalidRoomID",
},
{
name: "roomAlias does not exist",
device: aliceDev,
roomID: "#doesnotexist:test",
},
{
name: "room with guest_access event",
device: charlieDev,
roomID: crRespWithGuestAccess.RoomID,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
if tc.wantHTTP200 && !joinResp.Is2xx() {
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
}
})
}
})
}

View file

@ -7,6 +7,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -81,7 +82,7 @@ func Password(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil { if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
return *resErr return *resErr
} }

View file

@ -30,6 +30,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -60,8 +61,6 @@ var (
) )
const ( const (
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
sessionIDLength = 24 sessionIDLength = 24
) )
@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl
return nil return nil
} }
// validatePassword returns an error response if the password is invalid
func validatePassword(password string) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
}
}
return nil
}
// validateRecaptcha returns an error response if the captcha response is invalid // validateRecaptcha returns an error response if the captcha response is invalid
func validateRecaptcha( func validateRecaptcha(
cfg *config.ClientAPI, cfg *config.ClientAPI,
@ -636,7 +618,7 @@ func Register(
return *resErr return *resErr
} }
} }
if resErr := validatePassword(r.Password); resErr != nil { if resErr := internal.ValidatePassword(r.Password); resErr != nil {
return *resErr return *resErr
} }
@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
return *resErr return *resErr
} }
if resErr := validatePassword(ssrr.Password); resErr != nil { if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
return *resErr return *resErr
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"

View file

@ -20,6 +20,7 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -49,7 +50,7 @@ import (
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, base *base.BaseDendrite,
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, asAPI appserviceAPI.AppServiceInternalAPI,
@ -63,7 +64,14 @@ func Setup(
extRoomsProvider api.ExtraPublicRoomsProvider, extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, natsClient *nats.Conn, mscCfg *config.MSCs, natsClient *nats.Conn,
) { ) {
prometheus.MustRegister(amtRegUsers, sendEventDuration) publicAPIMux := base.PublicClientAPIMux
wkMux := base.PublicWellKnownAPIMux
synapseAdminRouter := base.SynapseAdminMux
dendriteAdminRouter := base.DendriteAdminMux
if base.EnableMetrics {
prometheus.MustRegister(amtRegUsers, sendEventDuration)
}
rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) rateLimits := httputil.NewRateLimits(&cfg.RateLimiting)
userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg)
@ -637,7 +645,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web", v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req) vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg) return AuthFallback(w, req, vars["authType"], cfg)
}), }),

View file

@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL
all rooms which they are currently joined. A JSON body will be returned containing all rooms which they are currently joined. A JSON body will be returned containing
the room IDs of all affected rooms. the room IDs of all affected rooms.
## POST `/_dendrite/admin/resetPassword/{localpart}` ## POST `/_dendrite/admin/resetPassword/{userID}`
Reset the password of a local user.
Request body format: Request body format:
@ -54,9 +56,6 @@ Request body format:
} }
``` ```
Reset the password of a local user. The `localpart` is the username only, i.e. if
the full user ID is `@alice:domain.com` then the local part is `alice`.
## GET `/_dendrite/admin/fulltext/reindex` ## GET `/_dendrite/admin/fulltext/reindex`
This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately.

View file

@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl
return nil return nil
} }
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if pdus, ok := d.associatedPDUs[serverName]; ok {
count = int64(len(pdus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var count int64
if edus, ok := d.associatedEDUs[serverName]; ok {
count = int64(len(edus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock() d.dbMutex.Lock()
defer d.dbMutex.Unlock() defer d.dbMutex.Unlock()

View file

@ -45,9 +45,6 @@ type Database interface {
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)

View file

@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" + const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts"
const renewInboundPeekSQL = "" + const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" + const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteInboundPeeksSQL = "" + const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er
return return
} }
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertInboundPeekStmt, insertInboundPeekSQL},
} {&s.selectInboundPeekStmt, selectInboundPeekSQL},
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { {&s.selectInboundPeekStmt, selectInboundPeekSQL},
return {&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
} {&s.renewInboundPeekStmt, renewInboundPeekSQL},
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
return {&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
} }.Prepare(db)
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
} }
func (s *inboundPeeksStatements) InsertInboundPeek( func (s *inboundPeeksStatements) InsertInboundPeek(

View file

@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" + const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewOutboundPeekSQL = "" + const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" + const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteOutboundPeeksSQL = "" + const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err
return return
} }
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
} {&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
return {&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
} {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
return }.Prepare(db)
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
} }
func (s *outboundPeeksStatements) InsertOutboundPeek( func (s *outboundPeeksStatements) InsertOutboundPeek(

View file

@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" + "SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1" " WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" + const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus" "SELECT DISTINCT server_name FROM federationsender_queue_edus"
@ -81,7 +77,6 @@ type queueEDUsStatements struct {
deleteQueueEDUStmt *sql.Stmt deleteQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, {&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
return count, err return count, err
} }
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {

View file

@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1" " WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUServerNamesSQL = "" + const selectQueuePDUServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus" "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
@ -71,7 +67,6 @@ type queuePDUsStatements struct {
deleteQueuePDUsStmt *sql.Stmt deleteQueuePDUsStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueuePDUReferenceJSONCountStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueuePDUServerNamesStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt
} }
@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
return return
} }
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
return return
} }
@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
return count, err return count, err
} }
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,

View file

@ -162,15 +162,6 @@ func (d *Database) CleanEDUs(
}) })
} }
// GetPendingEDUCount returns the number of EDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingEDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have EDUs // GetPendingServerNames returns the server names that have EDUs
// waiting to be sent. // waiting to be sent.
func (d *Database) GetPendingEDUServerNames( func (d *Database) GetPendingEDUServerNames(

View file

@ -141,15 +141,6 @@ func (d *Database) CleanPDUs(
}) })
} }
// GetPendingPDUCount returns the number of PDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingPDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have PDUs // GetPendingServerNames returns the server names that have PDUs
// waiting to be sent. // waiting to be sent.
func (d *Database) GetPendingPDUServerNames( func (d *Database) GetPendingPDUServerNames(

View file

@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" + const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewInboundPeekSQL = "" + const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" + const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteInboundPeeksSQL = "" + const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro
return return
} }
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertInboundPeekStmt, insertInboundPeekSQL},
} {&s.selectInboundPeekStmt, selectInboundPeekSQL},
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { {&s.selectInboundPeekStmt, selectInboundPeekSQL},
return {&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
} {&s.renewInboundPeekStmt, renewInboundPeekSQL},
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
return {&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
} }.Prepare(db)
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
} }
func (s *inboundPeeksStatements) InsertInboundPeek( func (s *inboundPeeksStatements) InsertInboundPeek(

View file

@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" + const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
const renewOutboundPeekSQL = "" + const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" + const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const deleteOutboundPeeksSQL = "" + const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er
return return
} }
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { return s, sqlutil.StatementList{
return {&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
} {&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
return {&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
} {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
return }.Prepare(db)
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
} }
func (s *outboundPeeksStatements) InsertOutboundPeek( func (s *outboundPeeksStatements) InsertOutboundPeek(

View file

@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" + "SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1" " WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" + const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus" "SELECT DISTINCT server_name FROM federationsender_queue_edus"
@ -82,7 +78,6 @@ type queueEDUsStatements struct {
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
selectQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt
selectExpiredEDUsStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt
deleteExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
{&s.insertQueueEDUStmt, insertQueueEDUSQL}, {&s.insertQueueEDUStmt, insertQueueEDUSQL},
{&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL},
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
return count, err return count, err
} }
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {

View file

@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" + "SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1" " WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUsServerNamesSQL = "" + const selectQueuePDUsServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus" "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
@ -79,7 +75,6 @@ type queuePDUsStatements struct {
selectQueueNextTransactionIDStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
} }
@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
return return
} }
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
return return
} }
@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
return count, err return count, err
} }
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs( func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,

View file

@ -2,10 +2,12 @@ package storage_test
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) {
assert.Equal(t, 2, len(data)) assert.Equal(t, 2, len(data))
}) })
} }
func TestOutboundPeeking(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
// Add outbound peek
if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if outboundPeek1.PeekID != peekID {
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
}
if outboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
}
if outboundPeek1.ServerName != serverName {
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
}
if outboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if outboundPeek1.ServerName != outboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
}
if outboundPeek1.RoomID != outboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
}
// insert some peeks
peekIDs := []string{peekID}
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
}
gotPeekIDs := make([]string, 0, len(outboundPeeks))
for _, p := range outboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}
func TestInboundPeeking(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
// Add inbound peek
if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if inboundPeek1.PeekID != peekID {
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
}
if inboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
}
if inboundPeek1.ServerName != serverName {
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
}
if inboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if inboundPeek1.ServerName != inboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
}
if inboundPeek1.RoomID != inboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
}
// insert some peeks
peekIDs := []string{peekID}
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
}
gotPeekIDs := make([]string, 0, len(inboundPeeks))
for _, p := range inboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}

View file

@ -0,0 +1,149 @@
package tables_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
var tab tables.FederationInboundPeeks
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresInboundPeeksTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteInboundPeeksTable(db)
}
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
return tab, close
}
func TestInboundPeeksTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateInboundpeeksTable(t, dbType)
defer closeDB()
// Insert a peek
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if inboundPeek1.PeekID != peekID {
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
}
if inboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
}
if inboundPeek1.ServerName != serverName {
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
}
if inboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if inboundPeek1.ServerName != inboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
}
if inboundPeek1.RoomID != inboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
}
// delete the peek
if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
t.Fatal(err)
}
// There should be no peek anymore
peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if peek != nil {
t.Fatalf("got a peek which should be deleted: %+v", peek)
}
// insert some peeks
var peekIDs []string
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
}
gotPeekIDs := make([]string, 0, len(inboundPeeks))
for _, p := range inboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
// And delete them again
if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil {
t.Fatal(err)
}
// they should be gone now
inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(inboundPeeks) > 0 {
t.Fatal("got inbound peeks which should be deleted")
}
})
}

View file

@ -28,7 +28,6 @@ type FederationQueuePDUs interface {
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
} }
@ -38,7 +37,6 @@ type FederationQueueEDUs interface {
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error

View file

@ -0,0 +1,148 @@
package tables_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
var tab tables.FederationOutboundPeeks
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresOutboundPeeksTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db)
}
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
return tab, close
}
func TestOutboundPeeksTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateOutboundpeeksTable(t, dbType)
defer closeDB()
// Insert a peek
peekID := util.RandomString(8)
var renewalInterval int64 = 1000
if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
t.Fatal(err)
}
// select the newly inserted peek
outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
// Assert fields are set as expected
if outboundPeek1.PeekID != peekID {
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
}
if outboundPeek1.RoomID != room.ID {
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
}
if outboundPeek1.ServerName != serverName {
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
}
if outboundPeek1.RenewalInterval != renewalInterval {
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
}
// Renew the peek
if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
t.Fatal(err)
}
// verify the values changed
outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
t.Fatal("expected a change peek, but they are the same")
}
if outboundPeek1.ServerName != outboundPeek2.ServerName {
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
}
if outboundPeek1.RoomID != outboundPeek2.RoomID {
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
}
// delete the peek
if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
t.Fatal(err)
}
// There should be no peek anymore
peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
if err != nil {
t.Fatal(err)
}
if peek != nil {
t.Fatalf("got a peek which should be deleted: %+v", peek)
}
// insert some peeks
var peekIDs []string
for i := 0; i < 5; i++ {
peekID = util.RandomString(8)
if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
t.Fatal(err)
}
peekIDs = append(peekIDs, peekID)
}
// Now select them
outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
}
gotPeekIDs := make([]string, 0, len(outboundPeeks))
for _, p := range outboundPeeks {
gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
// And delete them again
if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil {
t.Fatal(err)
}
// they should be gone now
outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID)
if err != nil {
t.Fatal(err)
}
if len(outboundPeeks) > 0 {
t.Fatal("got outbound peeks which should be deleted")
}
})
}

View file

@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeHTMLAPI adds Span metrics to the HTML Handler function // MakeHTMLAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages // This is used to serve HTML alongside JSON error messages
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) { withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName) span := opentracing.StartSpan(metricsName)
defer span.Finish() defer span.Finish()
@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)
} }
} }
if !enableMetrics {
return http.HandlerFunc(withSpan)
}
return promhttp.InstrumentHandlerCounter( return promhttp.InstrumentHandlerCounter(
promauto.NewCounterVec( promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{

44
internal/validate.go Normal file
View file

@ -0,0 +1,44 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"fmt"
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/util"
)
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
// ValidatePassword returns an error response if the password is invalid
func ValidatePassword(password string) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
}
}
return nil
}

View file

@ -78,6 +78,7 @@ const (
type PerformJoinRequest struct { type PerformJoinRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"` RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
IsGuest bool `json:"is_guest"`
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`

View file

@ -4,6 +4,10 @@ import (
"context" "context"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
asAPI "github.com/matrix-org/dendrite/appservice/api" asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -19,9 +23,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
) )
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI r.fsAPI = fsAPI
r.KeyRing = keyRing r.KeyRing = keyRing
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
if err != nil {
logrus.Panic(err)
}
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
Cfg: &r.Base.Cfg.RoomServer, Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base, Base: r.Base,
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
JetStream: r.JetStream, JetStream: r.JetStream,
NATSClient: r.NATSClient, NATSClient: r.NATSClient,
Durable: nats.Durable(r.Durable), Durable: nats.Durable(r.Durable),
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
SigningIdentity: identity,
FSAPI: fsAPI, FSAPI: fsAPI,
KeyRing: keyRing, KeyRing: keyRing,
ACLs: r.ServerACLs, ACLs: r.ServerACLs,
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Queryer: r.Queryer, Queryer: r.Queryer,
} }
r.Peeker = &perform.Peeker{ r.Peeker = &perform.Peeker{
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Unpeeker = &perform.Unpeeker{ r.Unpeeker = &perform.Unpeeker{
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI
} }
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {

View file

@ -23,6 +23,8 @@ import (
"sync" "sync"
"time" "time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -79,6 +81,7 @@ type Inputer struct {
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable nats.SubOpt
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
SigningIdentity *gomatrixserverlib.SigningIdentity
FSAPI fedapi.RoomserverFederationAPI FSAPI fedapi.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs ACLs *acls.ServerACLs
@ -87,6 +90,7 @@ type Inputer struct {
workers sync.Map // room ID -> *worker workers sync.Map // room ID -> *worker
Queryer *query.Queryer Queryer *query.Queryer
UserAPI userapi.RoomserverUserAPI
} }
// If a room consumer is inactive for a while then we will allow NATS // If a room consumer is inactive for a while then we will allow NATS

View file

@ -19,6 +19,7 @@ package input
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -31,6 +32,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
userAPI "github.com/matrix-org/dendrite/userapi/api"
fedapi "github.com/matrix-org/dendrite/federationapi/api" fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
} }
} }
// If guest_access changed and is not can_join, kick all guest users.
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
}
}
// Everything was OK — the latest events updater didn't error and // Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call. // we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered) hooks.Run(hooks.KindNewEventPersisted, headered)
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
succeeded = true succeeded = true
return nil return nil
} }
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
return err
}
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
if err != nil {
return err
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: event.RoomID(),
}
latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return err
}
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil {
continue
}
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: senderDomain,
}, accountRes); err != nil {
return err
}
if accountRes.Account == nil {
continue
}
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
continue
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
return err
}
memberContent.Membership = gomatrixserverlib.Leave
stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: event.RoomID(),
Type: gomatrixserverlib.MRoomMember,
StateKey: &stateKey,
Sender: stateKey,
PrevEvents: prevEvents,
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
return err
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil {
return err
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
if err != nil {
return err
}
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
Origin: senderDomain,
SendAsServer: string(senderDomain),
})
prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(),
}
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: inputEvents,
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
}
inputRes := &api.InputRoomEventsResponse{}
return r.InputRoomEvents(ctx, inputReq, inputRes)
}

View file

@ -16,6 +16,7 @@ package perform
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
} }
} }
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
if req.IsGuest {
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
guestAccess := "forbidden"
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
}
if guestAccessEvent != nil {
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
}
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
}
}
// If we should do a forced federated join then do that. // If we should do a forced federated join then do that.
var joinedVia gomatrixserverlib.ServerName var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin { if forceFederatedJoin {

View file

@ -3,23 +3,27 @@ package roomserver_test
import ( import (
"context" "context"
"net/http" "net/http"
"reflect"
"testing" "testing"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/userapi"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
) )
@ -34,7 +38,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
return base, db, close return base, db, close
} }
func Test_SharedUsers(t *testing.T) { func TestUsers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
t.Run("shared users", func(t *testing.T) {
testSharedUsers(t, rsAPI)
})
t.Run("kick users", func(t *testing.T) {
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
rsAPI.SetUserAPI(usrAPI)
testKickUsers(t, rsAPI, usrAPI)
})
})
}
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
alice := test.NewUser(t) alice := test.NewUser(t)
bob := test.NewUser(t) bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
@ -48,36 +73,93 @@ func Test_SharedUsers(t *testing.T) {
}, test.WithStateKey(bob.ID)) }, test.WithStateKey(bob.ID))
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, _, close := mustCreateDatabase(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base) // Create the room
// SetFederationAPI starts the room event input consumer if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
rsAPI.SetFederationAPI(nil, nil) t.Errorf("failed to send events: %v", err)
// Create the room }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) // Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
}
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
// Join with the guest user
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
// Create the users in the userapi, so the RSAPI can query the account type later
for _, u := range []*test.User{alice, bob} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &userAPI.PerformAccountCreationResponse{}
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
// Create the room in the database
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// Get the membership events BEFORE revoking guest access
membershipRes := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
t.Errorf("failed to query membership for room: %s", err)
}
// revoke guest access
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
// to loop and wait for the events to be processed by the roomserver.
for i := 0; i <= 20; i++ {
// Get the membership events AFTER revoking guest access
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
t.Errorf("failed to query membership for room: %s", err)
} }
// Query the shared users for Alice, there should only be Bob. // The membership events should NOT match, as Bob (guest user) should now be kicked from the room
// This is used by the SyncAPI keychange consumer. if !reflect.DeepEqual(membershipRes, membershipRes2) {
res := &api.QuerySharedUsersResponse{} return
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
} }
if _, ok := res.UserIDsToCount[bob.ID]; !ok { time.Sleep(time.Millisecond * 10)
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) }
}
// Also verify that we get the expected result when specifying OtherUserIDs. t.Errorf("memberships didn't change in time")
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
})
} }
func Test_QueryLeftUsers(t *testing.T) { func Test_QueryLeftUsers(t *testing.T) {

View file

@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
return id, nil return id, nil
} }
} }
return nil, fmt.Errorf("no signing identity %q", serverName) return nil, fmt.Errorf("no signing identity for %q", serverName)
} }
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {

View file

@ -16,8 +16,10 @@ package config
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
} }
} }
} }
func Test_SigningIdentityFor(t *testing.T) {
tests := []struct {
name string
virtualHosts []*VirtualHost
serverName gomatrixserverlib.ServerName
want *gomatrixserverlib.SigningIdentity
wantErr bool
}{
{
name: "no virtual hosts defined",
wantErr: true,
},
{
name: "no identity found",
serverName: gomatrixserverlib.ServerName("doesnotexist"),
wantErr: true,
},
{
name: "found identity",
serverName: gomatrixserverlib.ServerName("main"),
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
},
{
name: "identity found on virtual hosts",
serverName: gomatrixserverlib.ServerName("vh2"),
virtualHosts: []*VirtualHost{
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
},
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Global{
VirtualHosts: tt.virtualHosts,
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "main",
},
}
got, err := c.SigningIdentityFor(tt.serverName)
if (err != nil) != tt.wantErr {
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one
Leaves are present in non-gapped incremental syncs Leaves are present in non-gapped incremental syncs
# Below test was passing for the wrong reason, failing correctly since #2858 # Below test was passing for the wrong reason, failing correctly since #2858
New federated private chats get full presence information (SYN-115) New federated private chats get full presence information (SYN-115)
# We don't have any state to calculate m.room.guest_access when accepting invites
Guest users can accept invites to private rooms over federation

View file

@ -763,4 +763,7 @@ AS and main public room lists are separate
local user has tags copied to the new room local user has tags copied to the new room
remote user has tags copied to the new room remote user has tags copied to the new room
/upgrade moves remote aliases to the new room /upgrade moves remote aliases to the new room
Local and remote users' homeservers remove a room from their public directory on upgrade Local and remote users' homeservers remove a room from their public directory on upgrade
Guest users denied access over federation if guest access prohibited
Guest users are kicked from guest_access rooms on revocation of guest_access
Guest users are kicked from guest_access rooms on revocation of guest_access over federation

View file

@ -38,11 +38,12 @@ var (
) )
type Room struct { type Room struct {
ID string ID string
Version gomatrixserverlib.RoomVersion Version gomatrixserverlib.RoomVersion
preset Preset preset Preset
visibility gomatrixserverlib.HistoryVisibility guestCanJoin bool
creator *User visibility gomatrixserverlib.HistoryVisibility
creator *User
authEvents gomatrixserverlib.AuthEvents authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent currentState map[string]*gomatrixserverlib.HeaderedEvent
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
if r.guestCanJoin {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
"guest_access": "can_join",
}, WithStateKey(""))
}
} }
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
r.Version = ver r.Version = ver
} }
} }
func GuestsCanJoin(canJoin bool) roomModifier {
return func(t *testing.T, r *Room) {
r.guestCanJoin = canJoin
}
}

View file

@ -47,7 +47,7 @@ var (
type User struct { type User struct {
ID string ID string
accountType api.AccountType AccountType api.AccountType
// key ID and private key of the server who has this user, if known. // key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey privKey ed25519.PrivateKey
@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
func WithAccountType(accountType api.AccountType) UserOpt { func WithAccountType(accountType api.AccountType) UserOpt {
return func(u *User) { return func(u *User) {
u.accountType = accountType u.AccountType = accountType
} }
} }

View file

@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
type RoomserverUserAPI interface { type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
} }
// api functions required by the media api // api functions required by the media api
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
Medium string Medium string
} }
type QueryAccountByLocalpartRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryAccountByLocalpartResponse struct {
Account *Account
}

View file

@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err return err
} }
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil return nil
} }
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
return
}
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'. // creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {

View file

@ -60,6 +60,7 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
ctx context.Context,
req *api.QueryAccountByLocalpartRequest,
res *api.QueryAccountByLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
h.httpClient, ctx, req, res,
)
}

View file

@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
PerformSaveThreePIDAssociationPath, PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
) )
internalAPIMux.Handle(
QueryAccountByLocalpartPath,
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
)
} }

View file

@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
}) })
}) })
} }
func TestQueryAccountByLocalpart(t *testing.T) {
alice := test.NewUser(t)
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
if err != nil {
t.Error(err)
}
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
// Query existing account
queryAccResp := &api.QueryAccountByLocalpartResponse{}
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: userServername,
}, queryAccResp); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
}
// Query non-existent account, this should result in an error
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: "doesnotexist",
ServerName: userServername,
}, queryAccResp)
if err == nil {
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
}
}
t.Run("Monolith", func(t *testing.T) {
testCases(t, intAPI)
// also test tracing
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
})
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, intAPI, false)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
if err != nil {
t.Fatalf("failed to create HTTP client: %s", err)
}
testCases(t, userHTTPApi)
})
})
}