Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/fts

This commit is contained in:
Till Faelligen 2022-05-18 15:25:41 +02:00
commit a6f2b46c98
73 changed files with 1317 additions and 951 deletions

View file

@ -20,7 +20,7 @@ import (
"log"
"os"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/test"
)
const usage = `Usage: %s

84
docs/coverage.md Normal file
View file

@ -0,0 +1,84 @@
---
title: Coverage
parent: Development
permalink: /development/coverage
---
To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary:
```patch
diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm
index 8f0e209c..ad057e52 100644
--- a/lib/SyTest/Homeserver/Dendrite.pm
+++ b/lib/SyTest/Homeserver/Dendrite.pm
@@ -337,7 +337,7 @@ sub _start_monolith
$output->diag( "Starting monolith server" );
my @command = (
- $self->{bindir} . '/dendrite-monolith-server',
+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
'--config', $self->{paths}{config},
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh
index f009332b..7ea79869 100755
--- a/scripts/dendrite_sytest.sh
+++ b/scripts/dendrite_sytest.sh
@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin
echo >&2 "--- Building dendrite from source"
cd /src
mkdir -p $GOBIN
-go install -v ./cmd/dendrite-monolith-server
+# go install -v ./cmd/dendrite-monolith-server
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server
go install -v ./cmd/generate-keys
cd -
```
Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it,
ensure your working directory is under the Dendrite repository then run:
```bash
go tool cover -func=/path/to/server-0/integrationcover.log
```
which will produce an output like:
```
...
github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0%
github.com/matrix-org/util/json.go:90: Protect 57.1%
github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0%
github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0%
github.com/matrix-org/util/json.go:151: respond 61.5%
github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0%
github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0%
github.com/matrix-org/util/json.go:202: RandomString 100.0%
github.com/matrix-org/util/json.go:210: init 100.0%
github.com/matrix-org/util/unique.go:13: Unique 91.7%
github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0%
github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0%
total: (statements) 53.7%
```
The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations,
which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required
to remove packages which will never be tested and extension MSCs:
```bash
# These commands are all similar but change which package paths are _removed_ from the output.
# For Postgres (monolith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt
# For Postgres (polylith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt
# For SQLite (monolith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt
# For SQLite (polylith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt
```
A total value can then be calculated using:
```bash
cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}'
```
We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report.

View file

@ -12,12 +12,16 @@ import (
// FederationInternalAPI is used to query information from the federation sender.
type FederationInternalAPI interface {
FederationClient
gomatrixserverlib.FederatedStateClient
KeyserverFederationAPI
gomatrixserverlib.KeyDatabase
ClientFederationAPI
RoomserverFederationAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU(
@ -60,17 +64,43 @@ type RoomserverFederationAPI interface {
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError
type FederationClient interface {
gomatrixserverlib.FederatedStateClient
type KeyserverFederationAPI interface {
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
}
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
type FederationClient interface {
gomatrixserverlib.KeyClient
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
// FederationClientError is returned from FederationClient methods in the event of a problem.

View file

@ -39,7 +39,7 @@ type KeyChangeConsumer struct {
db storage.Database
queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName
rsAPI roomserverAPI.RoomserverInternalAPI
rsAPI roomserverAPI.FederationRoomserverAPI
topic string
}
@ -50,7 +50,7 @@ func NewKeyChangeConsumer(
js nats.JetStreamContext,
queues *queue.OutgoingQueues,
store storage.Database,
rsAPI roomserverAPI.RoomserverInternalAPI,
rsAPI roomserverAPI.FederationRoomserverAPI,
) *KeyChangeConsumer {
return &KeyChangeConsumer{
ctx: process.Context(),
@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
logger.WithError(err).Error("failed to calculate joined rooms for user")
return true
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil {

View file

@ -36,7 +36,7 @@ import (
type OutputRoomEventConsumer struct {
ctx context.Context
cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI
rsAPI api.FederationRoomserverAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
@ -51,7 +51,7 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext,
queues *queue.OutgoingQueues,
store storage.Database,
rsAPI api.RoomserverInternalAPI,
rsAPI api.FederationRoomserverAPI,
) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{
ctx: process.Context(),
@ -89,15 +89,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
switch output.Type {
case api.OutputTypeNewRoomEvent:
ev := output.NewRoomEvent.Event
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
log.WithError(err).Errorf("roomserver output log: purge room state failure")
return false
}
}
if err := s.processMessage(*output.NewRoomEvent); err != nil {
if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event_id": ev.EventID(),
@ -145,7 +137,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
// Ask the roomserver and add in the rest of the results into the set.
@ -164,7 +156,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
addsStateEvents = append(addsStateEvents, eventsRes.Events...)
}
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
if err != nil {
return err
}
@ -176,10 +168,9 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
oldJoinedHosts, err := s.db.UpdateRoom(
s.ctx,
ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts,
ore.RemovesStateEventIDs,
rewritesState, // if we're re-writing state, nuke all joined hosts before adding
)
if err != nil {
return err
@ -238,7 +229,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return nil, err
}
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents)
if err != nil {
return nil, err
}
@ -284,10 +275,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return result, nil
}
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
// JoinedHostsFromEvents turns a list of state events into a list of joined hosts.
// This errors if one of the events was invalid.
// It should be impossible for an invalid event to get this far in the pipeline.
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
var joinedHosts []types.JoinedHost
for _, ev := range evs {
if ev.Type() != "m.room.member" || ev.StateKey() == nil {

View file

@ -93,8 +93,8 @@ func AddPublicRoutes(
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
federation api.FederationClient,
rsAPI roomserverAPI.FederationRoomserverAPI,
caches *caching.Caches,
keyRing *gomatrixserverlib.KeyRing,
resetBlacklist bool,

View file

@ -3,18 +3,250 @@ package federationapi_test
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/internal/test"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
)
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil {
return
}
f.inputRoomEvents(ctx, req, res)
}
// keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if f.queryRoomsForUser == nil {
return nil
}
return f.queryRoomsForUser(ctx, req, res)
}
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
type fedClient struct {
api.FederationClient
allowJoins []*test.Room
keys map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}
t *testing.T
sentTxn bool
}
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
fmt.Println("GetServerKeys:", matrixServer)
var keys gomatrixserverlib.ServerKeys
var keyID gomatrixserverlib.KeyID
var pkey ed25519.PrivateKey
for srv, data := range f.keys {
if srv == matrixServer {
pkey = data.key
keyID = data.keyID
break
}
}
if pkey == nil {
return keys, nil
}
keys.ServerName = matrixServer
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
publicKey := pkey.Public().(ed25519.PublicKey)
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
keyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil {
return keys, err
}
keys.Raw, err = gomatrixserverlib.SignJSON(
string(matrixServer), keyID, pkey, toSign,
)
if err != nil {
return keys, err
}
return keys, nil
}
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == roomID {
res.RoomVersion = r.Version
res.JoinEvent = gomatrixserverlib.EventBuilder{
Sender: userID,
RoomID: roomID,
Type: "m.room.member",
StateKey: &userID,
Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)),
PrevEvents: r.ForwardExtremities(),
}
var needed gomatrixserverlib.StateNeeded
needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent)
if err != nil {
f.t.Errorf("StateNeededForEventBuilder: %v", err)
return
}
res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed)
return
}
}
return
}
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == event.RoomID() {
r.InsertEvent(f.t, event.Headered(r.Version))
f.t.Logf("Join event: %v", event.EventID())
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState())
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events())
}
}
return
}
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
for _, edu := range t.EDUs {
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
f.sentTxn = true
}
}
f.t.Logf("got /send")
return
}
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
// isn't relying on the roomserver.
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testFederationAPIJoinThenKeyUpdate(t, dbType)
})
}
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.FederationAPI.PreferDirectFetch = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
serverA := gomatrixserverlib.ServerName("server.a")
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
serverAPrivKey := test.PrivateKeyA
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
myServer := base.Cfg.Global.ServerName
myServerKeyID := base.Cfg.Global.KeyID
myServerPrivKey := base.Cfg.Global.PrivateKey
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)
rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
res.RoomIDs = []string{room.ID}
return nil
}
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
},
}
fc := &fedClient{
allowJoins: []*test.Room{room},
t: t,
keys: map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}{
serverA: {
key: serverAPrivKey,
keyID: serverAKeyID,
},
myServer: {
key: myServerPrivKey,
keyID: myServerKeyID,
},
},
}
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
var resp api.PerformJoinResponse
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
RoomID: room.ID,
UserID: joiningUser.ID,
ServerNames: []gomatrixserverlib.ServerName{serverA},
}, &resp)
if resp.JoinedVia != serverA {
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
}
if resp.LastError != nil {
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
}
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
key := keyapi.DeviceMessage{
Type: keyapi.TypeDeviceKeyUpdate,
DeviceKeys: &keyapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
KeyJSON: []byte(`{}`),
},
}
b, err := json.Marshal(key)
if err != nil {
t.Fatalf("Failed to marshal device message: %s", err)
}
msg := &nats.Msg{
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
Header: nats.Header{},
Data: b,
}
msg.Header.Set(jetstream.UserID, key.UserID)
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(500 * time.Millisecond)
if !fc.sentTxn {
t.Fatalf("did not send device list update")
}
}
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
}
gerr, ok := err.(gomatrix.HTTPError)
if !ok {
t.Errorf("failed to cast response error as gomatrix.HTTPError")
t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err)
continue
}
t.Logf("Error: %+v", gerr)

View file

@ -25,8 +25,8 @@ type FederationInternalAPI struct {
db storage.Database
cfg *config.FederationAPI
statistics *statistics.Statistics
rsAPI roomserverAPI.RoomserverInternalAPI
federation *gomatrixserverlib.FederationClient
rsAPI roomserverAPI.FederationRoomserverAPI
federation api.FederationClient
keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress
@ -34,8 +34,8 @@ type FederationInternalAPI struct {
func NewFederationInternalAPI(
db storage.Database, cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.FederationRoomserverAPI,
federation api.FederationClient,
statistics *statistics.Statistics,
caches *caching.Caches,
queues *queue.OutgoingQueues,

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix"
@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return fmt.Errorf("respSendJoin.Check: %w", err)
}
// We need to immediately update our list of joined hosts for this room now as we are technically
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
}
logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts")
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
}
// If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view.
@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider(
ctx context.Context, federation *gomatrixserverlib.FederationClient,
ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in

View file

@ -21,6 +21,7 @@ import (
"sync"
"time"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -49,8 +50,8 @@ type destinationQueue struct {
db storage.Database
process *process.ProcessContext
signing *SigningInfo
rsAPI api.RoomserverInternalAPI
client *gomatrixserverlib.FederationClient // federation client
rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests
destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running?

View file

@ -26,6 +26,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -39,9 +40,9 @@ type OutgoingQueues struct {
db storage.Database
process *process.ProcessContext
disabled bool
rsAPI api.RoomserverInternalAPI
rsAPI api.FederationRoomserverAPI
origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient
client fedapi.FederationClient
statistics *statistics.Statistics
signing *SigningInfo
queuesMutex sync.Mutex // protects the below
@ -85,8 +86,8 @@ func NewOutgoingQueues(
process *process.ProcessContext,
disabled bool,
origin gomatrixserverlib.ServerName,
client *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI,
client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics,
signing *SigningInfo,
) *OutgoingQueues {

View file

@ -30,7 +30,7 @@ import (
// RoomAliasToID converts the queried alias into a room ID and returns it
func RoomAliasToID(
httpReq *http.Request,
federation *gomatrixserverlib.FederationClient,
federation federationAPI.FederationClient,
cfg *config.FederationAPI,
rsAPI roomserverAPI.FederationRoomserverAPI,
senderAPI federationAPI.FederationInternalAPI,

View file

@ -54,7 +54,7 @@ func Setup(
rsAPI roomserverAPI.FederationRoomserverAPI,
fsAPI *fedInternal.FederationInternalAPI,
keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient,
federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI,
keyAPI keyserverAPI.FederationKeyAPI,
mscCfg *config.MSCs,

View file

@ -85,7 +85,7 @@ func Send(
rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI,
keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient,
federation federationAPI.FederationClient,
mu *internal.MutexByRoom,
servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer,

View file

@ -8,8 +8,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -57,7 +58,7 @@ var (
func CreateInvitesFrom3PIDInvites(
req *http.Request, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient,
federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI,
) util.JSONResponse {
var body invites
@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite(
roomID string,
rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient,
federation federationAPI.FederationClient,
) util.JSONResponse {
var builder gomatrixserverlib.EventBuilder
if err := json.Unmarshal(request.Content(), &builder); err != nil {
@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it
// acknowledged it
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event)
inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError()
}
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError()
@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite(
func createInviteFrom3PIDInvite(
ctx context.Context, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI,
inv invite, federation *gomatrixserverlib.FederationClient,
inv invite, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI,
) (*gomatrixserverlib.Event, error) {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
@ -335,7 +341,7 @@ func buildMembershipEvent(
// them responded with an error.
func sendToRemoteServer(
ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI,
federation federationAPI.FederationClient, _ *config.FederationAPI,
builder gomatrixserverlib.EventBuilder,
) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2)

View file

@ -25,13 +25,12 @@ import (
type Database interface {
gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -63,11 +63,21 @@ func (r *Receipt) String() string {
// this isn't a duplicate message.
func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string,
roomID string,
addHosts []types.JoinedHost,
removeHosts []string,
purgeRoomFirst bool,
) (joinedHosts []types.JoinedHost, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if purgeRoomFirst {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
}
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil {
return err
@ -138,20 +148,6 @@ func (d *Database) StoreJSON(
}, nil
}
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
return nil
})
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)

View file

@ -20,7 +20,7 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/test"
)
func TestEDUCache(t *testing.T) {

View file

@ -1,158 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
// Request contains the information necessary to issue a request and test its result
type Request struct {
Req *http.Request
WantedBody string
WantedStatusCode int
LastErr *LastRequestErr
}
// LastRequestErr is a synchronised error wrapper
// Useful for obtaining the last error from a set of requests
type LastRequestErr struct {
sync.Mutex
Err error
}
// Set sets the error
func (r *LastRequestErr) Set(err error) {
r.Lock()
defer r.Unlock()
r.Err = err
}
// Get gets the error
func (r *LastRequestErr) Get() error {
r.Lock()
defer r.Unlock()
return r.Err
}
// CanonicalJSONInput canonicalises a slice of JSON strings
// Useful for test input
func CanonicalJSONInput(jsonData []string) []string {
for i := range jsonData {
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
if err != nil && err != io.EOF {
panic(err)
}
jsonData[i] = string(jsonBytes)
}
return jsonData
}
// Do issues a request and checks the status code and body of the response
func (r *Request) Do() (err error) {
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
res, err := client.Do(r.Req)
if err != nil {
return err
}
defer (func() { err = res.Body.Close() })()
if res.StatusCode != r.WantedStatusCode {
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
}
if r.WantedBody != "" {
resBytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
if err != nil {
return err
}
if string(jsonBytes) != r.WantedBody {
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
}
}
return nil
}
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
// It then closes the given channel and returns.
func (r *Request) DoUntilSuccess(done chan error) {
r.LastErr = &LastRequestErr{}
for {
if err := r.Do(); err != nil {
r.LastErr.Set(err)
time.Sleep(1 * time.Second) // don't tightloop
continue
}
close(done)
return
}
}
// Run repeatedly issues a request until success, error or a timeout is reached
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
done := make(chan error, 1)
// We need to wait for the server to:
// - have connected to the database
// - have created the tables
// - be listening on the given port
go r.DoUntilSuccess(done)
// wait for one of:
// - the test to pass (done channel is closed)
// - the server to exit with an error (error sent on serverCmdChan)
// - our test timeout to expire
// We don't need to clean up since the main() function handles that in the event we panic
select {
case <-time.After(timeout):
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
if reqErr := r.LastErr.Get(); reqErr != nil {
fmt.Println("Last /sync request error:")
fmt.Println(reqErr)
}
panic(fmt.Sprintf("%v server timed out", label))
case err := <-serverCmdChan:
if err != nil {
fmt.Println("=============================================================================================")
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
fmt.Println(" export PGHOST=/var/run/postgresql")
fmt.Println("=============================================================================================")
panic(err)
}
case <-done:
fmt.Printf("==TESTING== %v PASSED\n", label)
}
}

View file

@ -1,76 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"io"
"os/exec"
"path/filepath"
"strings"
)
// KafkaExecutor executes kafka scripts.
type KafkaExecutor struct {
// The location of Zookeeper. Typically this is `localhost:2181`.
ZookeeperURI string
// The directory where Kafka is installed to. Used to locate kafka scripts.
KafkaDirectory string
// The location of the Kafka logs. Typically this is `localhost:9092`.
KafkaURI string
// Where stdout and stderr should be written to. Typically this is `os.Stderr`.
OutputWriter io.Writer
}
// CreateTopic creates a new kafka topic. This is created with a single partition.
func (e *KafkaExecutor) CreateTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--create",
"--zookeeper", e.ZookeeperURI,
"--replication-factor", "1",
"--partitions", "1",
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
return cmd.Run()
}
// WriteToTopic writes data to a kafka topic.
func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"),
"--broker-list", e.KafkaURI,
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
return cmd.Run()
}
// DeleteTopic deletes a given kafka topic if it exists.
func (e *KafkaExecutor) DeleteTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--delete",
"--if-exists",
"--zookeeper", e.ZookeeperURI,
"--topic", topic,
)
cmd.Stderr = e.OutputWriter
cmd.Stdout = e.OutputWriter
return cmd.Run()
}

View file

@ -1,152 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"context"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/matrix-org/dendrite/setup/config"
)
// Defaulting allows assignment of string variables with a fallback default value
// Useful for use with os.Getenv() for example
func Defaulting(value, defaultValue string) string {
if value == "" {
value = defaultValue
}
return value
}
// CreateDatabase creates a new database, dropping it first if it exists
func CreateDatabase(command string, args []string, database string) error {
cmd := exec.Command(command, args...)
cmd.Stdin = strings.NewReader(
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
)
// Send stdout and stderr to our stderr so that we see error messages from
// the psql process
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
return cmd.Run()
}
// CreateBackgroundCommand creates an executable command
// The Cmd being executed is returned. A channel is also returned,
// which will have any termination errors sent down it, followed immediately by the channel being closed.
func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) {
cmd := exec.Command(command, args...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stderr
if err := cmd.Start(); err != nil {
panic("failed to start server: " + err.Error())
}
cmdChan := make(chan error, 1)
go func() {
cmdChan <- cmd.Wait()
close(cmdChan)
}()
return cmd, cmdChan
}
// InitDatabase creates the database and config file needed for the server to run
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
if len(databases) > 0 {
var dbCmd string
var dbArgs []string
if postgresContainerName == "" {
dbCmd = "psql"
dbArgs = []string{postgresDatabase}
} else {
dbCmd = "docker"
dbArgs = []string{
"exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase,
}
}
for _, database := range databases {
if err := CreateDatabase(dbCmd, dbArgs, database); err != nil {
panic(err)
}
}
}
}
// StartProxy creates a reverse proxy
func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) {
proxyArgs := []string{
"--bind-address", bindAddr,
"--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect),
"--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect),
"--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect),
"--tls-cert", "server.crt",
"--tls-key", "server.key",
}
return CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"),
proxyArgs,
)
}
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if useTLS {
certFile := filepath.Join(os.TempDir(), "dendrite.cert")
keyFile := filepath.Join(os.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Logf("failed to generate tls key/cert: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
secure := ""
if useTLS {
secure = "s"
}
return fmt.Sprintf("http%s://localhost:%d", secure, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View file

@ -84,7 +84,7 @@ type DeviceListUpdater struct {
db DeviceListUpdaterDatabase
api DeviceListUpdaterAPI
producer KeyChangeProducer
fedClient fedsenderapi.FederationClient
fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
@ -127,7 +127,7 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.FederationClient, numWorkers int,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex),

View file

@ -37,7 +37,7 @@ import (
type KeyInternalAPI struct {
DB storage.Database
ThisServer gomatrixserverlib.ServerName
FedClient fedsenderapi.FederationClient
FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange
Updater *DeviceListUpdater

View file

@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)

View file

@ -183,6 +183,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
// Query a given amount (or less) of events prior to a given set of events.

View file

@ -12,7 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
)
@ -22,7 +22,7 @@ var jc *nats.Conn
func TestMain(m *testing.M) {
var b *base.BaseDendrite
b, js, jc = test.Base(nil)
b, js, jc = testrig.Base(nil)
code := m.Run()
b.ShutdownDendrite()
b.WaitForComponentsToFinish()

View file

@ -19,8 +19,8 @@ import (
"encoding/json"
"testing"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)

View file

@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
if err != nil {
return nil, err
}

View file

@ -61,12 +61,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
func createRoomAliasesTable(db *sql.DB) error {
func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{}
return s, sqlutil.StatementList{
@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
var aliases []string
for rows.Next() {
var alias string
for rows.Next() {
if err = rows.Scan(&alias); err != nil {
return nil, err
}

View file

@ -95,12 +95,12 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt
}
func createRoomsTable(db *sql.DB) error {
func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{}
return s, sqlutil.StatementList{
@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
for rows.Next() {
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}

View file

@ -19,7 +19,6 @@ import (
"context"
"database/sql"
"fmt"
"sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@ -71,12 +70,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}
func createStateBlockTable(db *sql.DB) error {
func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
return s, sqlutil.StatementList{
@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
var nids types.EventNIDs
for _, e := range entries {
nids = append(nids, e.EventNID)
nids := make(types.EventNIDs, entries.Len())
for i := range entries {
nids[i] = entries[i].EventNID
}
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
for ; rows.Next(); i++ {
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
r := []types.EventNID{}
for _, e := range result {
r = append(r, types.EventNID(e))
r := make([]types.EventNID, len(result))
for x := range result {
r[x] = types.EventNID(result[x])
}
results[i] = r
}
@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
}
return pq.Int64Array(nids)
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)]
var id int64
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}

View file

@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
if err := createRoomsTable(db); err != nil {
if err := CreateRoomsTable(db); err != nil {
return err
}
if err := createStateBlockTable(db); err != nil {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
if err := createRoomAliasesTable(db); err != nil {
if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
rooms, err := prepareRoomsTable(db)
rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
stateBlock, err := prepareStateBlockTable(db)
stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
roomAliases, err := prepareRoomAliasesTable(db)
roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}

View file

@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx, nil)
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
}
// ForgetRoom sets a users room to forgotten

View file

@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
for _, v := range eventNIDs {

View file

@ -63,12 +63,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
func createRoomAliasesTable(db *sql.DB) error {
func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
for rows.Next() {
var alias string
for rows.Next() {
if err = rows.Scan(&alias); err != nil {
return
}

View file

@ -86,12 +86,12 @@ type roomStatements struct {
selectRoomIDsStmt *sql.Stmt
}
func createRoomsTable(db *sql.DB) error {
func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{
db: db,
}
@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
for rows.Next() {
var roomID string
for rows.Next() {
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
for rows.Next() {
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}

View file

@ -20,7 +20,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/matrix-org/dendrite/internal"
@ -64,12 +63,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}
func createStateBlockTable(db *sql.DB) error {
func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
nids := types.EventNIDs{} // zero slice to not store 'null' in the DB
for _, e := range entries {
nids = append(nids, e.EventNID)
nids := make(types.EventNIDs, entries.Len())
for i := range entries {
nids[i] = entries[i].EventNID
}
js, err := json.Marshal(nids)
if err != nil {
@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result json.RawMessage
for ; rows.Next(); i++ {
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
r := []types.EventNID{}
var r []types.EventNID
if err = json.Unmarshal(result, &r); err != nil {
return nil, fmt.Errorf("json.Unmarshal: %w", err)
}
@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
}
return results, err
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
eventTypeNIDs = make([]int64, len(s))
eventStateKeyNIDs = make([]int64, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
return
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
var id int64
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDsJSON string
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err
}

View file

@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
if err := createRoomsTable(db); err != nil {
if err := CreateRoomsTable(db); err != nil {
return err
}
if err := createStateBlockTable(db); err != nil {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
if err := createRoomAliasesTable(db); err != nil {
if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
rooms, err := prepareRoomsTable(db)
rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
stateBlock, err := prepareStateBlockTable(db)
stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
roomAliases, err := prepareRoomAliasesTable(db)
roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}

View file

@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
}
func Test_EventsTable(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {

View file

@ -72,7 +72,7 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
}

View file

@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
func TestPreviousEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePreviousEventsTable(t, dbType)

View file

@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
func TestPublishedTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePublishedTable(t, dbType)

View file

@ -0,0 +1,96 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateRoomAliasesTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareRoomAliasesTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateRoomAliasesTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareRoomAliasesTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestRoomAliasesTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateRoomAliasesTable(t, dbType)
defer close()
alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost"
// insert aliases
err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID)
assert.NoError(t, err)
err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID)
assert.NoError(t, err)
err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID)
assert.NoError(t, err)
// verify we can get the roomID for the alias
roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias)
assert.NoError(t, err)
assert.Equal(t, room.ID, roomID)
// .. and the creator
creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias)
assert.NoError(t, err)
assert.Equal(t, alice.ID, creator)
creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
assert.Equal(t, "", creator)
roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
assert.Equal(t, "", roomID)
// get all aliases for a room
aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, []string{alias, alias2}, aliases)
// delete an alias and verify it's deleted
err = tab.DeleteRoomAlias(ctx, nil, alias2)
assert.NoError(t, err)
aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, []string{alias}, aliases)
// deleting the same alias should be a no-op
err = tab.DeleteRoomAlias(ctx, nil, alias2)
assert.NoError(t, err)
// Delete non-existent alias should be a no-op
err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
})
}

View file

@ -0,0 +1,128 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateRoomsTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareRoomsTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateRoomsTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareRoomsTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestRoomsTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateRoomsTable(t, dbType)
defer close()
wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version)
assert.NoError(t, err)
// Create dummy room
_, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version)
assert.NoError(t, err)
gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, wantRoomNID, gotRoomNID)
// Ensure non existent roomNID errors
roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost")
assert.Error(t, err)
assert.Equal(t, types.RoomNID(0), roomNID)
roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, &types.RoomInfo{
RoomNID: wantRoomNID,
RoomVersion: room.Version,
StateSnapshotNID: 0,
IsStub: true, // there are no latestEventNIDs
}, roomInfo)
roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
assert.NoError(t, err)
assert.Nil(t, roomInfo)
// There are no rooms with latestEventNIDs yet
roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
assert.NoError(t, err)
assert.Equal(t, 0, len(roomIDs))
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
assert.NoError(t, err)
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
// Room does not exist
_, ok := roomVersions[1337]
assert.False(t, ok)
roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
assert.NoError(t, err)
assert.Equal(t, []string{room.ID}, roomIDs)
roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"})
assert.NoError(t, err)
assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs)
wantEventNIDs := []types.EventNID{1, 2, 3}
lastEventSentNID := types.EventNID(3)
stateSnapshotNID := types.StateSnapshotNID(1)
// make the room "usable"
err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID)
assert.NoError(t, err)
roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, &types.RoomInfo{
RoomNID: wantRoomNID,
RoomVersion: room.Version,
StateSnapshotNID: 1,
IsStub: false,
}, roomInfo)
eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
assert.NoError(t, err)
assert.Equal(t, wantEventNIDs, eventNIDs)
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
// Again, doesn't exist
_, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337)
assert.Error(t, err)
eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID)
assert.NoError(t, err)
assert.Equal(t, wantEventNIDs, eventNIDs)
assert.Equal(t, types.EventNID(3), eventNID)
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
})
}

View file

@ -0,0 +1,92 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateStateBlockTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateBlockTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateStateBlockTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareStateBlockTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestStateBlockTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateStateBlockTable(t, dbType)
defer close()
// generate some dummy data
var entries types.StateEntries
for i := 0; i < 100; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries = append(entries, entry)
}
stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(1), stateBlockNID)
// generate a different hash, to get a new StateBlockNID
var entries2 types.StateEntries
for i := 100; i < 300; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries2 = append(entries2, entry)
}
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(2), stateBlockNID)
eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2})
assert.NoError(t, err)
assert.Equal(t, len(entries), len(eventNIDs[0]))
assert.Equal(t, len(entries2), len(eventNIDs[1]))
// try to get a StateBlockNID which does not exist
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5})
assert.Error(t, err)
// This should return an error, since we can only retrieve 1 StateBlock
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5})
assert.Error(t, err)
for i := 0; i < 65555; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries2 = append(entries2, entry)
}
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(3), stateBlockNID)
})
}

View file

@ -0,0 +1,86 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateSnapshotTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareStateSnapshotTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestStateSnapshotTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateStateSnapshotTable(t, dbType)
defer close()
// generate some dummy data
var stateBlockNIDs types.StateBlockNIDs
for i := 0; i < 100; i++ {
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
}
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// verify ON CONFLICT; Note: this updates the sequence!
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// create a second snapshot
var stateBlockNIDs2 types.StateBlockNIDs
for i := 100; i < 150; i++ {
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
}
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
assert.NoError(t, err)
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
assert.NoError(t, err)
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
// check we get an error if the state snapshot does not exist
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
assert.Error(t, err)
// create a second snapshot
for i := 0; i < 65555; i++ {
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
}
_, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
assert.NoError(t, err)
})
}

View file

@ -21,6 +21,7 @@ import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"golang.org/x/crypto/blake2b"
)
@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
return a.EventStateKeyNID < b.EventStateKeyNID
}
type StateKeyTupleSorter []StateKeyTuple
func (s StateKeyTupleSorter) Len() int { return len(s) }
func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
eventTypeNIDs = make([]int64, len(s))
eventStateKeyNIDs = make([]int64, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// A StateEntry is an entry in the room state of a matrix room.
type StateEntry struct {
StateKeyTuple

View file

@ -1,6 +1,7 @@
package types
import (
"sort"
"testing"
)
@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) {
}
}
}
func TestStateKeyTupleSorter(t *testing.T) {
input := StateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -138,9 +138,12 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
presence := msg.Header.Get("presence")
timestamp := msg.Header.Get("last_active_ts")
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
if fromSync { // do not process local presence changes; we already did this synchronously.
return true
}
ts, err := strconv.Atoi(timestamp)
if err != nil {
return true
@ -151,15 +154,19 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
newMsg := msg.Header.Get("status_msg")
statusMsg = &newMsg
}
// OK is already checked, so no need to do it again
// already checked, so no need to check error
p, _ := types.PresenceFromString(presence)
pos, err := s.db.UpdatePresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
if err != nil {
s.EmitPresence(ctx, userID, p, statusMsg, ts, fromSync)
return true
}
func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
if err != nil {
logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user")
return
}
s.stream.Advance(pos)
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
return true
}

View file

@ -55,7 +55,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
defer close()
@ -68,7 +68,7 @@ func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -171,7 +171,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})

View file

@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)

View file

@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)

View file

@ -53,19 +53,24 @@ type RequestPool struct {
streams *streams.Streams
Notifier *notifier.Notifier
producer PresencePublisher
consumer PresenceConsumer
}
type PresencePublisher interface {
SendPresence(userID string, presence types.Presence, statusMsg *string) error
}
type PresenceConsumer interface {
EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool)
}
// NewRequestPool makes a new RequestPool
func NewRequestPool(
db storage.Database, cfg *config.SyncAPI,
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
rsAPI roomserverAPI.SyncRoomserverAPI,
streams *streams.Streams, notifier *notifier.Notifier,
producer PresencePublisher, enableMetrics bool,
producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
) *RequestPool {
if enableMetrics {
prometheus.MustRegister(
@ -83,6 +88,7 @@ func NewRequestPool(
streams: streams,
Notifier: notifier,
producer: producer,
consumer: consumer,
}
go rp.cleanLastSeen()
go rp.cleanPresence(db, time.Minute*5)
@ -160,6 +166,13 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
logrus.WithError(err).Error("Unable to publish presence message from sync")
return
}
// now synchronously update our view of the world. It's critical we do this before calculating
// the /sync response else we may not return presence: online immediately.
rp.consumer.EmitPresence(
context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg,
int(gomatrixserverlib.AsTimestamp(time.Now())), true,
)
}
func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) {

View file

@ -38,6 +38,12 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream
return 0, nil
}
type dummyConsumer struct{}
func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
}
func TestRequestPool_updatePresence(t *testing.T) {
type args struct {
presence string
@ -45,6 +51,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
sleep time.Duration
}
publisher := &dummyPublisher{}
consumer := &dummyConsumer{}
syncMap := sync.Map{}
tests := []struct {
@ -101,6 +108,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
rp := &RequestPool{
presence: &syncMap,
producer: publisher,
consumer: consumer,
cfg: &config.SyncAPI{
Matrix: &config.Global{
JetStream: config.JetStream{

View file

@ -70,8 +70,17 @@ func AddPublicRoutes(
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
JetStream: js,
}
presenceConsumer := consumers.NewPresenceConsumer(
base.ProcessContext, cfg, js, natsClient, syncDB,
notifier, streams.PresenceStreamProvider,
userAPI,
)
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, base.EnableMetrics)
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
}
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
@ -137,15 +146,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start receipts consumer")
}
presenceConsumer := consumers.NewPresenceConsumer(
base.ProcessContext, cfg, js, natsClient, syncDB,
notifier, streams.PresenceStreamProvider,
userAPI,
)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
}
routing.Setup(
base.PublicClientAPIMux, requestPool, syncDB, userAPI,
rsAPI, cfg, base.Caches, fts,

View file

@ -15,9 +15,11 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/tidwall/gjson"
)
type syncRoomserverAPI struct {
@ -86,7 +88,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
}
func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
user := test.NewUser()
user := test.NewUser(t)
room := test.NewRoom(t, user)
alice := userapi.Device{
ID: "ALICEID",
@ -96,14 +98,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser,
}
base, close := test.CreateBaseDendrite(t, dbType)
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events())
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
test.MustPublishMsgs(t, jsctx, msgs...)
testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct {
name string
@ -173,7 +175,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
}
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
user := test.NewUser()
user := test.NewUser(t)
room := test.NewRoom(t, user)
alice := userapi.Device{
ID: "ALICEID",
@ -183,7 +185,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser,
}
base, close := test.CreateBaseDendrite(t, dbType)
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
@ -198,7 +200,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs {
test.MustPublishMsgs(t, jsctx, msg)
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond)
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
@ -255,6 +257,60 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
}
}
// Test that if we hit /sync we get back presence: online, regardless of whether messages get delivered
// via NATS. Regression test for a flakey test "User sees their own presence in a sync"
func TestSyncAPIUpdatePresenceImmediately(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testSyncAPIUpdatePresenceImmediately(t, dbType)
})
}
func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
user := test.NewUser(t)
alice := userapi.Device{
ID: "ALICEID",
UserID: user.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "Alice",
AccountType: userapi.AccountTypeUser,
}
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.Global.Presence.EnableOutbound = true
base.Cfg.Global.Presence.EnableInbound = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": alice.AccessToken,
"timeout": "0",
"set_presence": "online",
})))
if w.Code != 200 {
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
var res types.Response
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
if len(res.Presence.Events) != 1 {
t.Fatalf("expected 1 presence events, got: %+v", res.Presence.Events)
}
if res.Presence.Events[0].Sender != alice.UserID {
t.Errorf("sender: got %v want %v", res.Presence.Events[0].Sender, alice.UserID)
}
if res.Presence.Events[0].Type != "m.presence" {
t.Errorf("type: got %v want %v", res.Presence.Events[0].Type, "m.presence")
}
if gjson.ParseBytes(res.Presence.Events[0].Content).Get("presence").Str != "online" {
t.Errorf("content: not online, got %v", res.Presence.Events[0].Content)
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input))
for i, ev := range input {
@ -262,7 +318,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
if ev.StateKey() != nil {
addsStateIDs = append(addsStateIDs, ev.EventID())
}
result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
Type: rsapi.OutputTypeNewRoomEvent,
NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev,

View file

@ -44,8 +44,9 @@ func fatalError(t *testing.T, format string, args ...interface{}) {
}
func createLocalDB(t *testing.T, dbName string) {
if !Quiet {
t.Log("Note: tests require a postgres install accessible to the current user")
if _, err := exec.LookPath("createdb"); err != nil {
fatalError(t, "Note: tests require a postgres install accessible to the current user")
return
}
createDB := exec.Command("createdb", dbName)
if !Quiet {
@ -63,6 +64,9 @@ func createRemoteDB(t *testing.T, dbName, user, connStr string) {
if err != nil {
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
}
if err = db.Ping(); err != nil {
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
}
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
if err != nil {
pqErr, ok := err.(*pq.Error)

View file

@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier {
}
}
func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier {
return func(e *eventMods) {
e.keyID = keyID
}
}
func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
return func(e *eventMods) {
e.privKey = pkey
}
}
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
return func(e *eventMods) {
e.origin = origin
}
}
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))

View file

@ -2,10 +2,15 @@ package test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path/filepath"
"sync"
"testing"
)
@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http
}
return req
}
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if withTLS {
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Errorf("failed to make TLS key: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
s := ""
if withTLS {
s = "s"
}
return fmt.Sprintf("http%s://localhost:%d", s, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View file

@ -25,103 +25,19 @@ import (
"io/ioutil"
"math/big"
"os"
"path/filepath"
"strings"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2"
)
const (
// ConfigFile is the name of the config file for a server.
ConfigFile = "dendrite.yaml"
// ServerKeyFile is the name of the file holding the matrix server private key.
ServerKeyFile = "server_key.pem"
// TLSCertFile is the name of the file holding the TLS certificate used for federation.
TLSCertFile = "tls_cert.pem"
// TLSKeyFile is the name of the file holding the TLS key used for federation.
TLSKeyFile = "tls_key.pem"
// MediaDir is the name of the directory used to store media.
MediaDir = "media"
)
// MakeConfig makes a config suitable for running integration tests.
// Generates new matrix and TLS keys for the server.
func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) {
var cfg config.Dendrite
cfg.Defaults(true)
port := startPort
assignAddress := func() config.HTTPAddress {
result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port))
port++
return result
}
serverKeyPath := filepath.Join(configDir, ServerKeyFile)
tlsCertPath := filepath.Join(configDir, TLSKeyFile)
tlsKeyPath := filepath.Join(configDir, TLSCertFile)
mediaBasePath := filepath.Join(configDir, MediaDir)
if err := NewMatrixKey(serverKeyPath); err != nil {
return nil, 0, err
}
if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil {
return nil, 0, err
}
cfg.Version = config.Version
cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress())
cfg.Global.PrivateKeyPath = config.Path(serverKeyPath)
cfg.MediaAPI.BasePath = config.Path(mediaBasePath)
cfg.Global.JetStream.Addresses = []string{kafkaURI}
// TODO: Use different databases for the different schemas.
// Using the same database for every schema currently works because
// the table names are globally unique. But we might not want to
// rely on that in the future.
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database)
cfg.FederationAPI.Database.ConnectionString = config.DataSource(database)
cfg.KeyServer.Database.ConnectionString = config.DataSource(database)
cfg.MediaAPI.Database.ConnectionString = config.DataSource(database)
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.FederationAPI.InternalAPI.Listen = assignAddress()
cfg.KeyServer.InternalAPI.Listen = assignAddress()
cfg.MediaAPI.InternalAPI.Listen = assignAddress()
cfg.RoomServer.InternalAPI.Listen = assignAddress()
cfg.SyncAPI.InternalAPI.Listen = assignAddress()
cfg.UserAPI.InternalAPI.Listen = assignAddress()
cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen
cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen
cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen
cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen
cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen
cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen
cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen
return &cfg, port, nil
}
// WriteConfig writes the config file to the directory.
func WriteConfig(cfg *config.Dendrite, configDir string) error {
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666)
}
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
func NewMatrixKey(matrixKeyPath string) (err error) {
var data [35]byte

View file

@ -15,7 +15,6 @@
package test
import (
"crypto/ed25519"
"encoding/json"
"fmt"
"sync/atomic"
@ -35,12 +34,6 @@ var (
PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0)
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
type Room struct {
@ -50,6 +43,7 @@ type Room struct {
creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
events []*gomatrixserverlib.HeaderedEvent
}
@ -57,14 +51,16 @@ type Room struct {
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1)
// set defaults then let roomModifiers override
if creator.srvName == "" {
t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
}
r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter),
ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
}
for _, m := range modifiers {
m(t, r)
@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
return r
}
func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference {
t.Helper()
a, err := needed.AuthEventReferences(&r.authEvents)
if err != nil {
t.Fatalf("MustGetAuthEvents: %v", err)
}
return a
}
func (r *Room) ForwardExtremities() []string {
if len(r.events) == 0 {
return nil
}
return []string{
r.events[len(r.events)-1].EventID(),
}
}
func (r *Room) insertCreateEvents(t *testing.T) {
t.Helper()
var joinRule gomatrixserverlib.JoinRuleContent
@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) {
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID,
"room_version": r.Version,
@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
}
if mod.privKey == nil {
mod.privKey = testPrivateKey
mod.privKey = creator.privKey
}
if mod.keyID == "" {
mod.keyID = testKeyID
mod.keyID = creator.keyID
}
if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now()
}
if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost")
mod.origin = creator.srvName
}
var unsigned gomatrixserverlib.RawJSON
@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
// Add a new event to this room DAG. Not thread-safe.
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
t.Helper()
// Add the event to the list of auth events
// Add the event to the list of auth/state events
r.events = append(r.events, he)
if he.StateKey() != nil {
err := r.authEvents.AddEvent(he.Unwrap())
if err != nil {
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
}
r.currentState[he.Type()+" "+*he.StateKey()] = he
}
}
@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
return r.events
}
func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent {
events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState))
i := 0
for _, e := range r.currentState {
events[i] = e
i++
}
return events
}
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper()
he := r.CreateEvent(t, creator, eventType, content, mods...)

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package test
package testrig
import (
"errors"
@ -24,22 +24,23 @@ import (
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/nats-io/nats.go"
)
func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) {
func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) {
var cfg config.Dendrite
cfg.Defaults(false)
cfg.Global.JetStream.InMemory = true
switch dbType {
case DBTypePostgres:
case test.DBTypePostgres:
cfg.Global.Defaults(true) // autogen a signing key
cfg.MediaAPI.Defaults(true) // autogen a media path
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
connStr, close := PrepareDBConnectionString(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
cfg.Global.DatabaseOptions = config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
MaxOpenConnections: 10,
@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()
ConnMaxLifetimeSeconds: 60,
}
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
case DBTypeSQLite:
case test.DBTypeSQLite:
cfg.Defaults(true) // sets a sqlite db per component
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :(

View file

@ -1,4 +1,4 @@
package test
package testrig
import (
"encoding/json"

View file

@ -15,22 +15,64 @@
package test
import (
"crypto/ed25519"
"fmt"
"sync/atomic"
"testing"
"github.com/matrix-org/gomatrixserverlib"
)
var (
userIDCounter = int64(0)
serverName = gomatrixserverlib.ServerName("test")
keyID = gomatrixserverlib.KeyID("ed25519:test")
privateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
// private keys that tests can use
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
})
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
})
)
type User struct {
ID string
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
srvName gomatrixserverlib.ServerName
}
func NewUser() *User {
type UserOpt func(*User)
func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
return func(u *User) {
u.keyID = keyID
u.privKey = privKey
u.srvName = srvName
}
}
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{
ID: fmt.Sprintf("@%d:localhost", counter),
var u User
for _, opt := range opts {
opt(&u)
}
return u
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
t.Logf("NewUser: missing signing server credentials; using default.")
WithSigningServer(serverName, keyID, privateKey)(&u)
}
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
t.Logf("NewUser: created user %s", u.ID)
return &u
}

View file

@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
}
func Test_Devices(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
}
func Test_KeyBackup(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
}
func Test_LoginToken(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
}
func Test_OpenID(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
}
func Test_Profile(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
}
func Test_Pusher(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
}
func Test_ThreePID(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
}
func Test_Notification(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)

View file

@ -24,7 +24,6 @@ import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
internalTest "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) {
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, userAPI)
apiURL, cancel := internalTest.ListenAndServe(t, router, false)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
if err != nil {