Handle guest access [1/2?] (#2872)
Needs https://github.com/matrix-org/sytest/pull/1315, as otherwise the membership events aren't persisted yet when hitting `/state` after kicking guest users. Makes the following tests pass: ``` Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access Guest users are kicked from guest_access rooms on revocation of guest_access over federation ``` Todo (in a follow up PR): - Restrict access to CS API Endpoints as per https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14 Co-authored-by: kegsay <kegan@matrix.org>
This commit is contained in:
parent
09dff951d6
commit
5eed31fea3
|
@ -15,6 +15,8 @@
|
||||||
package clientapi
|
package clientapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/api"
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
|
@ -26,7 +28,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
||||||
|
|
|
@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
|
||||||
joinReq := roomserverAPI.PerformJoinRequest{
|
joinReq := roomserverAPI.PerformJoinRequest{
|
||||||
RoomIDOrAlias: roomIDOrAlias,
|
RoomIDOrAlias: roomIDOrAlias,
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
|
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||||
Content: map[string]interface{}{},
|
Content: map[string]interface{}{},
|
||||||
}
|
}
|
||||||
joinRes := roomserverAPI.PerformJoinResponse{}
|
joinRes := roomserverAPI.PerformJoinResponse{}
|
||||||
|
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
|
||||||
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
||||||
done <- jsonerror.InternalAPIError(req.Context(), err)
|
done <- jsonerror.InternalAPIError(req.Context(), err)
|
||||||
} else if joinRes.Error != nil {
|
} else if joinRes.Error != nil {
|
||||||
|
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
|
||||||
|
done <- util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
done <- joinRes.Error.JSONResponse()
|
done <- joinRes.Error.JSONResponse()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
done <- util.JSONResponse{
|
done <- util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
158
clientapi/routing/joinroom_test.go
Normal file
158
clientapi/routing/joinroom_test.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/appservice"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
|
||||||
|
|
||||||
|
// Create the users in the userapi
|
||||||
|
for _, u := range []*test.User{alice, bob, charlie} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &uapi.PerformAccountCreationResponse{}
|
||||||
|
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
aliceDev := &uapi.Device{UserID: alice.ID}
|
||||||
|
bobDev := &uapi.Device{UserID: bob.ID}
|
||||||
|
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
|
||||||
|
|
||||||
|
// create a room with disabled guest access and invite Bob
|
||||||
|
resp := createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
RoomAliasName: "alias",
|
||||||
|
Invite: []string{bob.ID},
|
||||||
|
GuestCanJoin: false,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crResp, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a room with guest access enabled and invite Charlie
|
||||||
|
resp = createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
Invite: []string{charlie.ID},
|
||||||
|
GuestCanJoin: true,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dummy request
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
device *uapi.Device
|
||||||
|
roomID string
|
||||||
|
wantHTTP200 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User can join successfully by alias",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User can join successfully by roomID",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "join is forbidden if user is guest",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "!doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user from different server",
|
||||||
|
device: &uapi.Device{UserID: "@wrong:server"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user doesn't exist locally",
|
||||||
|
device: &uapi.Device{UserID: "@doesnotexist:test"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid room ID",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "invalidRoomID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "roomAlias does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "#doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room with guest_access event",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crRespWithGuestAccess.RoomID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
|
||||||
|
if tc.wantHTTP200 && !joinResp.Is2xx() {
|
||||||
|
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -78,6 +78,7 @@ const (
|
||||||
type PerformJoinRequest struct {
|
type PerformJoinRequest struct {
|
||||||
RoomIDOrAlias string `json:"room_id_or_alias"`
|
RoomIDOrAlias string `json:"room_id_or_alias"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
IsGuest bool `json:"is_guest"`
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||||
Unsigned map[string]interface{} `json:"unsigned"`
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
|
|
|
@ -4,6 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
@ -19,9 +23,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
||||||
|
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
r.fsAPI = fsAPI
|
r.fsAPI = fsAPI
|
||||||
r.KeyRing = keyRing
|
r.KeyRing = keyRing
|
||||||
|
|
||||||
|
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
r.Inputer = &input.Inputer{
|
r.Inputer = &input.Inputer{
|
||||||
Cfg: &r.Base.Cfg.RoomServer,
|
Cfg: &r.Base.Cfg.RoomServer,
|
||||||
Base: r.Base,
|
Base: r.Base,
|
||||||
|
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
JetStream: r.JetStream,
|
JetStream: r.JetStream,
|
||||||
NATSClient: r.NATSClient,
|
NATSClient: r.NATSClient,
|
||||||
Durable: nats.Durable(r.Durable),
|
Durable: nats.Durable(r.Durable),
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
|
SigningIdentity: identity,
|
||||||
FSAPI: fsAPI,
|
FSAPI: fsAPI,
|
||||||
KeyRing: keyRing,
|
KeyRing: keyRing,
|
||||||
ACLs: r.ServerACLs,
|
ACLs: r.ServerACLs,
|
||||||
|
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Queryer: r.Queryer,
|
Queryer: r.Queryer,
|
||||||
}
|
}
|
||||||
r.Peeker = &perform.Peeker{
|
r.Peeker = &perform.Peeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Inputer: r.Inputer,
|
Inputer: r.Inputer,
|
||||||
}
|
}
|
||||||
r.Unpeeker = &perform.Unpeeker{
|
r.Unpeeker = &perform.Unpeeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||||
r.Leaver.UserAPI = userAPI
|
r.Leaver.UserAPI = userAPI
|
||||||
|
r.Inputer.UserAPI = userAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||||
|
|
|
@ -23,6 +23,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/Arceliar/phony"
|
"github.com/Arceliar/phony"
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -79,6 +81,7 @@ type Inputer struct {
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
Durable nats.SubOpt
|
Durable nats.SubOpt
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
SigningIdentity *gomatrixserverlib.SigningIdentity
|
||||||
FSAPI fedapi.RoomserverFederationAPI
|
FSAPI fedapi.RoomserverFederationAPI
|
||||||
KeyRing gomatrixserverlib.JSONVerifier
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
ACLs *acls.ServerACLs
|
ACLs *acls.ServerACLs
|
||||||
|
@ -87,6 +90,7 @@ type Inputer struct {
|
||||||
workers sync.Map // room ID -> *worker
|
workers sync.Map // room ID -> *worker
|
||||||
|
|
||||||
Queryer *query.Queryer
|
Queryer *query.Queryer
|
||||||
|
UserAPI userapi.RoomserverUserAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a room consumer is inactive for a while then we will allow NATS
|
// If a room consumer is inactive for a while then we will allow NATS
|
||||||
|
|
|
@ -19,6 +19,7 @@ package input
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
@ -31,6 +32,8 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If guest_access changed and is not can_join, kick all guest users.
|
||||||
|
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
|
||||||
|
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Everything was OK — the latest events updater didn't error and
|
// Everything was OK — the latest events updater didn't error and
|
||||||
// we've sent output events. Finally, generate a hook call.
|
// we've sent output events. Finally, generate a hook call.
|
||||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||||
|
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
|
||||||
succeeded = true
|
succeeded = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
|
||||||
|
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
|
||||||
|
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
|
||||||
|
latestReq := &api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
}
|
||||||
|
latestRes := &api.QueryLatestEventsAndStateResponse{}
|
||||||
|
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prevEvents := latestRes.LatestEvents
|
||||||
|
for _, memberEvent := range memberEvents {
|
||||||
|
if memberEvent.StateKey() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: senderDomain,
|
||||||
|
}, accountRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if accountRes.Account == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var memberContent gomatrixserverlib.MemberContent
|
||||||
|
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
memberContent.Membership = gomatrixserverlib.Leave
|
||||||
|
|
||||||
|
stateKey := *memberEvent.StateKey()
|
||||||
|
fledglingEvent := &gomatrixserverlib.EventBuilder{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
Type: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: &stateKey,
|
||||||
|
Sender: stateKey,
|
||||||
|
PrevEvents: prevEvents,
|
||||||
|
}
|
||||||
|
|
||||||
|
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents = append(inputEvents, api.InputRoomEvent{
|
||||||
|
Kind: api.KindNew,
|
||||||
|
Event: event,
|
||||||
|
Origin: senderDomain,
|
||||||
|
SendAsServer: string(senderDomain),
|
||||||
|
})
|
||||||
|
prevEvents = []gomatrixserverlib.EventReference{
|
||||||
|
event.EventReference(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputReq := &api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: inputEvents,
|
||||||
|
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
|
||||||
|
}
|
||||||
|
inputRes := &api.InputRoomEventsResponse{}
|
||||||
|
return r.InputRoomEvents(ctx, inputReq, inputRes)
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ package perform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||||
|
if req.IsGuest {
|
||||||
|
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
|
||||||
|
guestAccess := "forbidden"
|
||||||
|
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
|
||||||
|
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||||
|
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||||
|
}
|
||||||
|
if guestAccessEvent != nil {
|
||||||
|
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||||
|
// is present on the room and has the guest_access value can_join.
|
||||||
|
if guestAccess != "can_join" {
|
||||||
|
return "", "", &rsAPI.PerformError{
|
||||||
|
Code: rsAPI.PerformErrorNotAllowed,
|
||||||
|
Msg: "Guest access is forbidden",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we should do a forced federated join then do that.
|
// If we should do a forced federated join then do that.
|
||||||
var joinedVia gomatrixserverlib.ServerName
|
var joinedVia gomatrixserverlib.ServerName
|
||||||
if forceFederatedJoin {
|
if forceFederatedJoin {
|
||||||
|
|
|
@ -3,18 +3,23 @@ package roomserver_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver"
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
)
|
)
|
||||||
|
@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
|
||||||
return base, db, close
|
return base, db, close
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_SharedUsers(t *testing.T) {
|
func TestUsers(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer close()
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
// SetFederationAPI starts the room event input consumer
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
|
||||||
|
t.Run("shared users", func(t *testing.T) {
|
||||||
|
testSharedUsers(t, rsAPI)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("kick users", func(t *testing.T) {
|
||||||
|
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
|
||||||
|
rsAPI.SetUserAPI(usrAPI)
|
||||||
|
testKickUsers(t, rsAPI, usrAPI)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
bob := test.NewUser(t)
|
bob := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||||
|
@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) {
|
||||||
}, test.WithStateKey(bob.ID))
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
||||||
base, _, close := mustCreateDatabase(t, dbType)
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
|
||||||
// SetFederationAPI starts the room event input consumer
|
|
||||||
rsAPI.SetFederationAPI(nil, nil)
|
|
||||||
// Create the room
|
// Create the room
|
||||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
t.Fatalf("failed to send events: %v", err)
|
t.Errorf("failed to send events: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the shared users for Alice, there should only be Bob.
|
// Query the shared users for Alice, there should only be Bob.
|
||||||
// This is used by the SyncAPI keychange consumer.
|
// This is used by the SyncAPI keychange consumer.
|
||||||
res := &api.QuerySharedUsersResponse{}
|
res := &api.QuerySharedUsersResponse{}
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
t.Errorf("unable to query known users: %v", err)
|
||||||
}
|
}
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
}
|
}
|
||||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||||
// This is used by the SyncAPI when getting device list changes.
|
// This is used by the SyncAPI when getting device list changes.
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
t.Errorf("unable to query known users: %v", err)
|
||||||
}
|
}
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
|
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
|
||||||
|
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
|
||||||
|
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
|
||||||
|
|
||||||
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
|
||||||
|
|
||||||
|
// Join with the guest user
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create the users in the userapi, so the RSAPI can query the account type later
|
||||||
|
for _, u := range []*test.User{alice, bob} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &userAPI.PerformAccountCreationResponse{}
|
||||||
|
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the room in the database
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the membership events BEFORE revoking guest access
|
||||||
|
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// revoke guest access
|
||||||
|
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
|
||||||
|
// to loop and wait for the events to be processed by the roomserver.
|
||||||
|
for i := 0; i <= 20; i++ {
|
||||||
|
// Get the membership events AFTER revoking guest access
|
||||||
|
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The membership events should NOT match, as Bob (guest user) should now be kicked from the room
|
||||||
|
if !reflect.DeepEqual(membershipRes, membershipRes2) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Errorf("memberships didn't change in time")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_QueryLeftUsers(t *testing.T) {
|
func Test_QueryLeftUsers(t *testing.T) {
|
||||||
|
|
|
@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("no signing identity %q", serverName)
|
return nil, fmt.Errorf("no signing identity for %q", serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
||||||
|
|
|
@ -16,8 +16,10 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_SigningIdentityFor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
virtualHosts []*VirtualHost
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
want *gomatrixserverlib.SigningIdentity
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no virtual hosts defined",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no identity found",
|
||||||
|
serverName: gomatrixserverlib.ServerName("doesnotexist"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "found identity",
|
||||||
|
serverName: gomatrixserverlib.ServerName("main"),
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identity found on virtual hosts",
|
||||||
|
serverName: gomatrixserverlib.ServerName("vh2"),
|
||||||
|
virtualHosts: []*VirtualHost{
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
|
||||||
|
},
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Global{
|
||||||
|
VirtualHosts: tt.virtualHosts,
|
||||||
|
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||||
|
ServerName: "main",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got, err := c.SigningIdentityFor(tt.serverName)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -49,3 +49,6 @@ Leaves are present in non-gapped incremental syncs
|
||||||
|
|
||||||
# Below test was passing for the wrong reason, failing correctly since #2858
|
# Below test was passing for the wrong reason, failing correctly since #2858
|
||||||
New federated private chats get full presence information (SYN-115)
|
New federated private chats get full presence information (SYN-115)
|
||||||
|
|
||||||
|
# We don't have any state to calculate m.room.guest_access when accepting invites
|
||||||
|
Guest users can accept invites to private rooms over federation
|
|
@ -764,3 +764,6 @@ local user has tags copied to the new room
|
||||||
remote user has tags copied to the new room
|
remote user has tags copied to the new room
|
||||||
/upgrade moves remote aliases to the new room
|
/upgrade moves remote aliases to the new room
|
||||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||||
|
Guest users denied access over federation if guest access prohibited
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access over federation
|
12
test/room.go
12
test/room.go
|
@ -41,6 +41,7 @@ type Room struct {
|
||||||
ID string
|
ID string
|
||||||
Version gomatrixserverlib.RoomVersion
|
Version gomatrixserverlib.RoomVersion
|
||||||
preset Preset
|
preset Preset
|
||||||
|
guestCanJoin bool
|
||||||
visibility gomatrixserverlib.HistoryVisibility
|
visibility gomatrixserverlib.HistoryVisibility
|
||||||
creator *User
|
creator *User
|
||||||
|
|
||||||
|
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||||
|
if r.guestCanJoin {
|
||||||
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
|
||||||
|
"guest_access": "can_join",
|
||||||
|
}, WithStateKey(""))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||||
|
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
||||||
r.Version = ver
|
r.Version = ver
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GuestsCanJoin(canJoin bool) roomModifier {
|
||||||
|
return func(t *testing.T, r *Room) {
|
||||||
|
r.guestCanJoin = canJoin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
|
||||||
|
|
||||||
type RoomserverUserAPI interface {
|
type RoomserverUserAPI interface {
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// api functions required by the media api
|
// api functions required by the media api
|
||||||
|
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
Medium string
|
Medium string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartRequest struct {
|
||||||
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartResponse struct {
|
||||||
|
Account *Account
|
||||||
|
}
|
||||||
|
|
|
@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
|
||||||
|
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
|
||||||
|
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
||||||
// creating a 'device'.
|
// creating a 'device'.
|
||||||
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
||||||
|
|
|
@ -60,6 +60,7 @@ const (
|
||||||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||||
|
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
|
||||||
h.httpClient, ctx, request, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryAccountByLocalpartRequest,
|
||||||
|
res *api.QueryAccountByLocalpartResponse,
|
||||||
|
) error {
|
||||||
|
return httputil.CallInternalRPCAPI(
|
||||||
|
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
|
||||||
|
h.httpClient, ctx, req, res,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
|
||||||
PerformSaveThreePIDAssociationPath,
|
PerformSaveThreePIDAssociationPath,
|
||||||
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
QueryAccountByLocalpartPath,
|
||||||
|
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryAccountByLocalpart(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
|
||||||
|
// Query existing account
|
||||||
|
queryAccResp := &api.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
|
||||||
|
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query non-existent account, this should result in an error
|
||||||
|
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: "doesnotexist",
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Monolith", func(t *testing.T) {
|
||||||
|
testCases(t, intAPI)
|
||||||
|
// also test tracing
|
||||||
|
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HTTP API", func(t *testing.T) {
|
||||||
|
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||||
|
userapi.AddInternalRoutes(router, intAPI, false)
|
||||||
|
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create HTTP client: %s", err)
|
||||||
|
}
|
||||||
|
testCases(t, userHTTPApi)
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue