Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/fts

This commit is contained in:
Till Faelligen 2022-05-18 15:25:41 +02:00
commit a6f2b46c98
73 changed files with 1317 additions and 951 deletions

View file

@ -20,7 +20,7 @@ import (
"log" "log"
"os" "os"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test"
) )
const usage = `Usage: %s const usage = `Usage: %s

84
docs/coverage.md Normal file
View file

@ -0,0 +1,84 @@
---
title: Coverage
parent: Development
permalink: /development/coverage
---
To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary:
```patch
diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm
index 8f0e209c..ad057e52 100644
--- a/lib/SyTest/Homeserver/Dendrite.pm
+++ b/lib/SyTest/Homeserver/Dendrite.pm
@@ -337,7 +337,7 @@ sub _start_monolith
$output->diag( "Starting monolith server" );
my @command = (
- $self->{bindir} . '/dendrite-monolith-server',
+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
'--config', $self->{paths}{config},
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh
index f009332b..7ea79869 100755
--- a/scripts/dendrite_sytest.sh
+++ b/scripts/dendrite_sytest.sh
@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin
echo >&2 "--- Building dendrite from source"
cd /src
mkdir -p $GOBIN
-go install -v ./cmd/dendrite-monolith-server
+# go install -v ./cmd/dendrite-monolith-server
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server
go install -v ./cmd/generate-keys
cd -
```
Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it,
ensure your working directory is under the Dendrite repository then run:
```bash
go tool cover -func=/path/to/server-0/integrationcover.log
```
which will produce an output like:
```
...
github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0%
github.com/matrix-org/util/json.go:90: Protect 57.1%
github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0%
github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0%
github.com/matrix-org/util/json.go:151: respond 61.5%
github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0%
github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0%
github.com/matrix-org/util/json.go:202: RandomString 100.0%
github.com/matrix-org/util/json.go:210: init 100.0%
github.com/matrix-org/util/unique.go:13: Unique 91.7%
github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0%
github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0%
total: (statements) 53.7%
```
The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations,
which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required
to remove packages which will never be tested and extension MSCs:
```bash
# These commands are all similar but change which package paths are _removed_ from the output.
# For Postgres (monolith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt
# For Postgres (polylith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt
# For SQLite (monolith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt
# For SQLite (polylith)
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt
```
A total value can then be calculated using:
```bash
cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}'
```
We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report.

View file

@ -12,12 +12,16 @@ import (
// FederationInternalAPI is used to query information from the federation sender. // FederationInternalAPI is used to query information from the federation sender.
type FederationInternalAPI interface { type FederationInternalAPI interface {
FederationClient gomatrixserverlib.FederatedStateClient
KeyserverFederationAPI
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
ClientFederationAPI ClientFederationAPI
RoomserverFederationAPI RoomserverFederationAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -60,17 +64,43 @@ type RoomserverFederationAPI interface {
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type FederationClient interface { type KeyserverFederationAPI interface {
gomatrixserverlib.FederatedStateClient
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
}
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
type FederationClient interface {
gomatrixserverlib.KeyClient
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.

View file

@ -39,7 +39,7 @@ type KeyChangeConsumer struct {
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.FederationRoomserverAPI
topic string topic string
} }
@ -50,7 +50,7 @@ func NewKeyChangeConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
) *KeyChangeConsumer { ) *KeyChangeConsumer {
return &KeyChangeConsumer{ return &KeyChangeConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
logger.WithError(err).Error("failed to calculate joined rooms for user") logger.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {

View file

@ -36,7 +36,7 @@ import (
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
@ -51,7 +51,7 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI api.RoomserverInternalAPI, rsAPI api.FederationRoomserverAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -89,15 +89,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
switch output.Type { switch output.Type {
case api.OutputTypeNewRoomEvent: case api.OutputTypeNewRoomEvent:
ev := output.NewRoomEvent.Event ev := output.NewRoomEvent.Event
if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil {
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
log.WithError(err).Errorf("roomserver output log: purge room state failure")
return false
}
}
if err := s.processMessage(*output.NewRoomEvent); err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event_id": ev.EventID(),
@ -145,7 +137,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// processMessage updates the list of currently joined hosts in the room // processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event. // and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
// Ask the roomserver and add in the rest of the results into the set. // Ask the roomserver and add in the rest of the results into the set.
@ -164,7 +156,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
addsStateEvents = append(addsStateEvents, eventsRes.Events...) addsStateEvents = append(addsStateEvents, eventsRes.Events...)
} }
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
if err != nil { if err != nil {
return err return err
} }
@ -176,10 +168,9 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
oldJoinedHosts, err := s.db.UpdateRoom( oldJoinedHosts, err := s.db.UpdateRoom(
s.ctx, s.ctx,
ore.Event.RoomID(), ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts, addsJoinedHosts,
ore.RemovesStateEventIDs, ore.RemovesStateEventIDs,
rewritesState, // if we're re-writing state, nuke all joined hosts before adding
) )
if err != nil { if err != nil {
return err return err
@ -238,7 +229,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return nil, err return nil, err
} }
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents) combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -284,10 +275,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return result, nil return result, nil
} }
// joinedHostsFromEvents turns a list of state events into a list of joined hosts. // JoinedHostsFromEvents turns a list of state events into a list of joined hosts.
// This errors if one of the events was invalid. // This errors if one of the events was invalid.
// It should be impossible for an invalid event to get this far in the pipeline. // It should be impossible for an invalid event to get this far in the pipeline.
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
var joinedHosts []types.JoinedHost var joinedHosts []types.JoinedHost
for _, ev := range evs { for _, ev := range evs {
if ev.Type() != "m.room.member" || ev.StateKey() == nil { if ev.Type() != "m.room.member" || ev.StateKey() == nil {

View file

@ -93,8 +93,8 @@ func AddPublicRoutes(
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, base *base.BaseDendrite,
federation *gomatrixserverlib.FederationClient, federation api.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
caches *caching.Caches, caches *caching.Caches,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
resetBlacklist bool, resetBlacklist bool,

View file

@ -3,18 +3,250 @@ package federationapi_test
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json"
"fmt"
"strings" "strings"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/internal/test" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
) )
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil {
return
}
f.inputRoomEvents(ctx, req, res)
}
// keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if f.queryRoomsForUser == nil {
return nil
}
return f.queryRoomsForUser(ctx, req, res)
}
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
type fedClient struct {
api.FederationClient
allowJoins []*test.Room
keys map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}
t *testing.T
sentTxn bool
}
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
fmt.Println("GetServerKeys:", matrixServer)
var keys gomatrixserverlib.ServerKeys
var keyID gomatrixserverlib.KeyID
var pkey ed25519.PrivateKey
for srv, data := range f.keys {
if srv == matrixServer {
pkey = data.key
keyID = data.keyID
break
}
}
if pkey == nil {
return keys, nil
}
keys.ServerName = matrixServer
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
publicKey := pkey.Public().(ed25519.PublicKey)
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
keyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil {
return keys, err
}
keys.Raw, err = gomatrixserverlib.SignJSON(
string(matrixServer), keyID, pkey, toSign,
)
if err != nil {
return keys, err
}
return keys, nil
}
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == roomID {
res.RoomVersion = r.Version
res.JoinEvent = gomatrixserverlib.EventBuilder{
Sender: userID,
RoomID: roomID,
Type: "m.room.member",
StateKey: &userID,
Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)),
PrevEvents: r.ForwardExtremities(),
}
var needed gomatrixserverlib.StateNeeded
needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent)
if err != nil {
f.t.Errorf("StateNeededForEventBuilder: %v", err)
return
}
res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed)
return
}
}
return
}
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == event.RoomID() {
r.InsertEvent(f.t, event.Headered(r.Version))
f.t.Logf("Join event: %v", event.EventID())
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState())
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events())
}
}
return
}
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
for _, edu := range t.EDUs {
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
f.sentTxn = true
}
}
f.t.Logf("got /send")
return
}
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
// isn't relying on the roomserver.
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testFederationAPIJoinThenKeyUpdate(t, dbType)
})
}
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.FederationAPI.PreferDirectFetch = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
serverA := gomatrixserverlib.ServerName("server.a")
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
serverAPrivKey := test.PrivateKeyA
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
myServer := base.Cfg.Global.ServerName
myServerKeyID := base.Cfg.Global.KeyID
myServerPrivKey := base.Cfg.Global.PrivateKey
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)
rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
res.RoomIDs = []string{room.ID}
return nil
}
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
},
}
fc := &fedClient{
allowJoins: []*test.Room{room},
t: t,
keys: map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}{
serverA: {
key: serverAPrivKey,
keyID: serverAKeyID,
},
myServer: {
key: myServerPrivKey,
keyID: myServerKeyID,
},
},
}
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
var resp api.PerformJoinResponse
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
RoomID: room.ID,
UserID: joiningUser.ID,
ServerNames: []gomatrixserverlib.ServerName{serverA},
}, &resp)
if resp.JoinedVia != serverA {
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
}
if resp.LastError != nil {
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
}
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
key := keyapi.DeviceMessage{
Type: keyapi.TypeDeviceKeyUpdate,
DeviceKeys: &keyapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
KeyJSON: []byte(`{}`),
},
}
b, err := json.Marshal(key)
if err != nil {
t.Fatalf("Failed to marshal device message: %s", err)
}
msg := &nats.Msg{
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
Header: nats.Header{},
Data: b,
}
msg.Header.Set(jetstream.UserID, key.UserID)
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(500 * time.Millisecond)
if !fc.sentTxn {
t.Fatalf("did not send device list update")
}
}
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. // Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. // Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
func TestRoomsV3URLEscapeDoNot404(t *testing.T) { func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
} }
gerr, ok := err.(gomatrix.HTTPError) gerr, ok := err.(gomatrix.HTTPError)
if !ok { if !ok {
t.Errorf("failed to cast response error as gomatrix.HTTPError") t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err)
continue continue
} }
t.Logf("Error: %+v", gerr) t.Logf("Error: %+v", gerr)

View file

@ -25,8 +25,8 @@ type FederationInternalAPI struct {
db storage.Database db storage.Database
cfg *config.FederationAPI cfg *config.FederationAPI
statistics *statistics.Statistics statistics *statistics.Statistics
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.FederationRoomserverAPI
federation *gomatrixserverlib.FederationClient federation api.FederationClient
keyRing *gomatrixserverlib.KeyRing keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress joins sync.Map // joins currently in progress
@ -34,8 +34,8 @@ type FederationInternalAPI struct {
func NewFederationInternalAPI( func NewFederationInternalAPI(
db storage.Database, cfg *config.FederationAPI, db storage.Database, cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
federation *gomatrixserverlib.FederationClient, federation api.FederationClient,
statistics *statistics.Statistics, statistics *statistics.Statistics,
caches *caching.Caches, caches *caching.Caches,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
} }
// We need to immediately update our list of joined hosts for this room now as we are technically
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
}
logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts")
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
}
// If we successfully performed a send_join above then the other // If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly // server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view. // returned state to the roomserver to update our local view.
@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation *gomatrixserverlib.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in

View file

@ -21,6 +21,7 @@ import (
"sync" "sync"
"time" "time"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -49,21 +50,21 @@ type destinationQueue struct {
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing *SigningInfo
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
client *gomatrixserverlib.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests
destination gomatrixserverlib.ServerName // destination of requests destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running? running atomic.Bool // is the queue worker running?
backingOff atomic.Bool // true if we're backing off backingOff atomic.Bool // true if we're backing off
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
statistics *statistics.ServerStatistics // statistics about this remote server statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
notify chan struct{} // interrupts idle wait pending PDUs/EDUs notify chan struct{} // interrupts idle wait pending PDUs/EDUs
pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.

View file

@ -26,6 +26,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -39,9 +40,9 @@ type OutgoingQueues struct {
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
disabled bool disabled bool
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing *SigningInfo
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
@ -85,8 +86,8 @@ func NewOutgoingQueues(
process *process.ProcessContext, process *process.ProcessContext,
disabled bool, disabled bool,
origin gomatrixserverlib.ServerName, origin gomatrixserverlib.ServerName,
client *gomatrixserverlib.FederationClient, client fedapi.FederationClient,
rsAPI api.RoomserverInternalAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing *SigningInfo,
) *OutgoingQueues { ) *OutgoingQueues {

View file

@ -30,7 +30,7 @@ import (
// RoomAliasToID converts the queried alias into a room ID and returns it // RoomAliasToID converts the queried alias into a room ID and returns it
func RoomAliasToID( func RoomAliasToID(
httpReq *http.Request, httpReq *http.Request,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
cfg *config.FederationAPI, cfg *config.FederationAPI,
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
senderAPI federationAPI.FederationInternalAPI, senderAPI federationAPI.FederationInternalAPI,

View file

@ -54,7 +54,7 @@ func Setup(
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
fsAPI *fedInternal.FederationInternalAPI, fsAPI *fedInternal.FederationInternalAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
keyAPI keyserverAPI.FederationKeyAPI, keyAPI keyserverAPI.FederationKeyAPI,
mscCfg *config.MSCs, mscCfg *config.MSCs,

View file

@ -85,7 +85,7 @@ func Send(
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI, keyAPI keyapi.FederationKeyAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
mu *internal.MutexByRoom, mu *internal.MutexByRoom,
servers federationAPI.ServersInRoomProvider, servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer, producer *producers.SyncAPIProducer,

View file

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -57,7 +58,7 @@ var (
func CreateInvitesFrom3PIDInvites( func CreateInvitesFrom3PIDInvites(
req *http.Request, rsAPI api.FederationRoomserverAPI, req *http.Request, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var body invites var body invites
@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite(
roomID string, roomID string,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
var builder gomatrixserverlib.EventBuilder var builder gomatrixserverlib.EventBuilder
if err := json.Unmarshal(request.Content(), &builder); err != nil { if err := json.Unmarshal(request.Content(), &builder); err != nil {
@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it // Ask the requesting server to sign the newly created event so we know it
// acknowledged it // acknowledged it
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event) inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError()
}
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite(
func createInviteFrom3PIDInvite( func createInviteFrom3PIDInvite(
ctx context.Context, rsAPI api.FederationRoomserverAPI, ctx context.Context, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
inv invite, federation *gomatrixserverlib.FederationClient, inv invite, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
@ -335,7 +341,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, _ *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)

View file

@ -25,13 +25,12 @@ import (
type Database interface { type Database interface {
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -63,11 +63,21 @@ func (r *Receipt) String() string {
// this isn't a duplicate message. // this isn't a duplicate message.
func (d *Database) UpdateRoom( func (d *Database) UpdateRoom(
ctx context.Context, ctx context.Context,
roomID, oldEventID, newEventID string, roomID string,
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
purgeRoomFirst bool,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if purgeRoomFirst {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
}
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
@ -138,20 +148,6 @@ func (d *Database) StoreJSON(
}, nil }, nil
} }
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
return nil
})
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)

View file

@ -20,7 +20,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test"
) )
func TestEDUCache(t *testing.T) { func TestEDUCache(t *testing.T) {

View file

@ -1,158 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
// Request contains the information necessary to issue a request and test its result
type Request struct {
Req *http.Request
WantedBody string
WantedStatusCode int
LastErr *LastRequestErr
}
// LastRequestErr is a synchronised error wrapper
// Useful for obtaining the last error from a set of requests
type LastRequestErr struct {
sync.Mutex
Err error
}
// Set sets the error
func (r *LastRequestErr) Set(err error) {
r.Lock()
defer r.Unlock()
r.Err = err
}
// Get gets the error
func (r *LastRequestErr) Get() error {
r.Lock()
defer r.Unlock()
return r.Err
}
// CanonicalJSONInput canonicalises a slice of JSON strings
// Useful for test input
func CanonicalJSONInput(jsonData []string) []string {
for i := range jsonData {
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
if err != nil && err != io.EOF {
panic(err)
}
jsonData[i] = string(jsonBytes)
}
return jsonData
}
// Do issues a request and checks the status code and body of the response
func (r *Request) Do() (err error) {
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
res, err := client.Do(r.Req)
if err != nil {
return err
}
defer (func() { err = res.Body.Close() })()
if res.StatusCode != r.WantedStatusCode {
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
}
if r.WantedBody != "" {
resBytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
if err != nil {
return err
}
if string(jsonBytes) != r.WantedBody {
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
}
}
return nil
}
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
// It then closes the given channel and returns.
func (r *Request) DoUntilSuccess(done chan error) {
r.LastErr = &LastRequestErr{}
for {
if err := r.Do(); err != nil {
r.LastErr.Set(err)
time.Sleep(1 * time.Second) // don't tightloop
continue
}
close(done)
return
}
}
// Run repeatedly issues a request until success, error or a timeout is reached
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
done := make(chan error, 1)
// We need to wait for the server to:
// - have connected to the database
// - have created the tables
// - be listening on the given port
go r.DoUntilSuccess(done)
// wait for one of:
// - the test to pass (done channel is closed)
// - the server to exit with an error (error sent on serverCmdChan)
// - our test timeout to expire
// We don't need to clean up since the main() function handles that in the event we panic
select {
case <-time.After(timeout):
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
if reqErr := r.LastErr.Get(); reqErr != nil {
fmt.Println("Last /sync request error:")
fmt.Println(reqErr)
}
panic(fmt.Sprintf("%v server timed out", label))
case err := <-serverCmdChan:
if err != nil {
fmt.Println("=============================================================================================")
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
fmt.Println(" export PGHOST=/var/run/postgresql")
fmt.Println("=============================================================================================")
panic(err)
}
case <-done:
fmt.Printf("==TESTING== %v PASSED\n", label)
}
}

View file

@ -1,76 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"io"
"os/exec"
"path/filepath"
"strings"
)
// KafkaExecutor executes kafka scripts.
type KafkaExecutor struct {
// The location of Zookeeper. Typically this is `localhost:2181`.
ZookeeperURI string
// The directory where Kafka is installed to. Used to locate kafka scripts.
KafkaDirectory string
// The location of the Kafka logs. Typically this is `localhost:9092`.
KafkaURI string
// Where stdout and stderr should be written to. Typically this is `os.Stderr`.
OutputWriter io.Writer
}
// CreateTopic creates a new kafka topic. This is created with a single partition.
func (e *KafkaExecutor) CreateTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--create",
"--zookeeper", e.ZookeeperURI,
"--replication-factor", "1",
"--partitions", "1",
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
return cmd.Run()
}
// WriteToTopic writes data to a kafka topic.
func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"),
"--broker-list", e.KafkaURI,
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
return cmd.Run()
}
// DeleteTopic deletes a given kafka topic if it exists.
func (e *KafkaExecutor) DeleteTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--delete",
"--if-exists",
"--zookeeper", e.ZookeeperURI,
"--topic", topic,
)
cmd.Stderr = e.OutputWriter
cmd.Stdout = e.OutputWriter
return cmd.Run()
}

View file

@ -1,152 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"context"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/matrix-org/dendrite/setup/config"
)
// Defaulting allows assignment of string variables with a fallback default value
// Useful for use with os.Getenv() for example
func Defaulting(value, defaultValue string) string {
if value == "" {
value = defaultValue
}
return value
}
// CreateDatabase creates a new database, dropping it first if it exists
func CreateDatabase(command string, args []string, database string) error {
cmd := exec.Command(command, args...)
cmd.Stdin = strings.NewReader(
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
)
// Send stdout and stderr to our stderr so that we see error messages from
// the psql process
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
return cmd.Run()
}
// CreateBackgroundCommand creates an executable command
// The Cmd being executed is returned. A channel is also returned,
// which will have any termination errors sent down it, followed immediately by the channel being closed.
func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) {
cmd := exec.Command(command, args...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stderr
if err := cmd.Start(); err != nil {
panic("failed to start server: " + err.Error())
}
cmdChan := make(chan error, 1)
go func() {
cmdChan <- cmd.Wait()
close(cmdChan)
}()
return cmd, cmdChan
}
// InitDatabase creates the database and config file needed for the server to run
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
if len(databases) > 0 {
var dbCmd string
var dbArgs []string
if postgresContainerName == "" {
dbCmd = "psql"
dbArgs = []string{postgresDatabase}
} else {
dbCmd = "docker"
dbArgs = []string{
"exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase,
}
}
for _, database := range databases {
if err := CreateDatabase(dbCmd, dbArgs, database); err != nil {
panic(err)
}
}
}
}
// StartProxy creates a reverse proxy
func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) {
proxyArgs := []string{
"--bind-address", bindAddr,
"--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect),
"--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect),
"--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect),
"--tls-cert", "server.crt",
"--tls-key", "server.key",
}
return CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"),
proxyArgs,
)
}
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if useTLS {
certFile := filepath.Join(os.TempDir(), "dendrite.cert")
keyFile := filepath.Join(os.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Logf("failed to generate tls key/cert: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
secure := ""
if useTLS {
secure = "s"
}
return fmt.Sprintf("http%s://localhost:%d", secure, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View file

@ -84,7 +84,7 @@ type DeviceListUpdater struct {
db DeviceListUpdaterDatabase db DeviceListUpdaterDatabase
api DeviceListUpdaterAPI api DeviceListUpdaterAPI
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.FederationClient fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
@ -127,7 +127,7 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater( func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.FederationClient, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex), userIDToMutex: make(map[string]*sync.Mutex),

View file

@ -37,7 +37,7 @@ import (
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName ThisServer gomatrixserverlib.ServerName
FedClient fedsenderapi.FederationClient FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
Updater *DeviceListUpdater Updater *DeviceListUpdater

View file

@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
) api.KeyInternalAPI { ) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)

View file

@ -183,6 +183,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.

View file

@ -12,7 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
@ -22,7 +22,7 @@ var jc *nats.Conn
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
var b *base.BaseDendrite var b *base.BaseDendrite
b, js, jc = test.Base(nil) b, js, jc = testrig.Base(nil)
code := m.Run() code := m.Run()
b.ShutdownDendrite() b.ShutdownDendrite()
b.WaitForComponentsToFinish() b.WaitForComponentsToFinish()

View file

@ -19,8 +19,8 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )

View file

@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -61,12 +61,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt deleteRoomAliasStmt *sql.Stmt
} }
func createRoomAliasesTable(db *sql.DB) error { func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema) _, err := db.Exec(roomAliasesSchema)
return err return err
} }
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{} s := &roomAliasesStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
var aliases []string var aliases []string
var alias string
for rows.Next() { for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil { if err = rows.Scan(&alias); err != nil {
return nil, err return nil, err
} }

View file

@ -95,12 +95,12 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt bulkSelectRoomNIDsStmt *sql.Stmt
} }
func createRoomsTable(db *sql.DB) error { func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema) _, err := db.Exec(roomsSchema)
return err return err
} }
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{} s := &roomStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx) rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
var roomID string
for rows.Next() { for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil { if err = rows.Scan(&roomID); err != nil {
return nil, err return nil, err
} }
@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() { for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil { if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err return nil, err
} }
@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
var roomID string
for rows.Next() { for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil { if err = rows.Scan(&roomID); err != nil {
return nil, err return nil, err
} }
@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() { for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil { if err = rows.Scan(&roomNID); err != nil {
return nil, err return nil, err
} }

View file

@ -19,7 +19,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"sort"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -71,12 +70,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt
} }
func createStateBlockTable(db *sql.DB) error { func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema) _, err := db.Exec(stateDataSchema)
return err return err
} }
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{} s := &stateBlockStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
var nids types.EventNIDs nids := make(types.EventNIDs, entries.Len())
for _, e := range entries { for i := range entries {
nids = append(nids, e.EventNID) nids[i] = entries[i].EventNID
} }
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext( err = stmt.QueryRowContext(
@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs)) results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0 i := 0
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
if err = rows.Scan(&stateBlockNID, &result); err != nil { if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err return nil, err
} }
r := []types.EventNID{} r := make([]types.EventNID, len(result))
for _, e := range result { for x := range result {
r = append(r, types.EventNID(e)) r[x] = types.EventNID(result[x])
} }
results[i] = r results[i] = r
} }
@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
} }
return pq.Int64Array(nids) return pq.Int64Array(nids)
} }
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
} }
func createStateSnapshotTable(db *sql.DB) error { func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema) _, err := db.Exec(stateSnapshotSchema)
return err return err
} }
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{} s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)] nids = nids[:util.SortAndUnique(nids)]
var id int64 err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
if err != nil { if err != nil {
return 0, err return 0, err
} }
stateNID = types.StateSnapshotNID(id)
return return
} }
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs)) results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0 i := 0
var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err return nil, err
} }

View file

@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil { if err := CreateEventsTable(db); err != nil {
return err return err
} }
if err := createRoomsTable(db); err != nil { if err := CreateRoomsTable(db); err != nil {
return err return err
} }
if err := createStateBlockTable(db); err != nil { if err := CreateStateBlockTable(db); err != nil {
return err return err
} }
if err := createStateSnapshotTable(db); err != nil { if err := CreateStateSnapshotTable(db); err != nil {
return err return err
} }
if err := CreatePrevEventsTable(db); err != nil { if err := CreatePrevEventsTable(db); err != nil {
return err return err
} }
if err := createRoomAliasesTable(db); err != nil { if err := CreateRoomAliasesTable(db); err != nil {
return err return err
} }
if err := CreateInvitesTable(db); err != nil { if err := CreateInvitesTable(db); err != nil {
@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
rooms, err := prepareRoomsTable(db) rooms, err := PrepareRoomsTable(db)
if err != nil { if err != nil {
return err return err
} }
stateBlock, err := prepareStateBlockTable(db) stateBlock, err := PrepareStateBlockTable(db)
if err != nil { if err != nil {
return err return err
} }
stateSnapshot, err := prepareStateSnapshotTable(db) stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil { if err != nil {
return err return err
} }
@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
roomAliases, err := prepareRoomAliasesTable(db) roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx, nil) return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
} }
// ForgetRoom sets a users room to forgotten // ForgetRoom sets a users room to forgotten

View file

@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
for _, v := range eventNIDs { for _, v := range eventNIDs {

View file

@ -63,12 +63,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt deleteRoomAliasStmt *sql.Stmt
} }
func createRoomAliasesTable(db *sql.DB) error { func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema) _, err := db.Exec(roomAliasesSchema)
return err return err
} }
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{ s := &roomAliasesStatements{
db: db, db: db,
} }
@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
var alias string
for rows.Next() { for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil { if err = rows.Scan(&alias); err != nil {
return return
} }

View file

@ -86,12 +86,12 @@ type roomStatements struct {
selectRoomIDsStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt
} }
func createRoomsTable(db *sql.DB) error { func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema) _, err := db.Exec(roomsSchema)
return err return err
} }
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{ s := &roomStatements{
db: db, db: db,
} }
@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx) rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
var roomID string
for rows.Next() { for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil { if err = rows.Scan(&roomID); err != nil {
return nil, err return nil, err
} }
@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() { for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil { if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err return nil, err
} }
@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string var roomIDs []string
var roomID string
for rows.Next() { for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil { if err = rows.Scan(&roomID); err != nil {
return nil, err return nil, err
} }
@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() { for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil { if err = rows.Scan(&roomNID); err != nil {
return nil, err return nil, err
} }

View file

@ -20,7 +20,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort"
"strings" "strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -64,12 +63,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt
} }
func createStateBlockTable(db *sql.DB) error { func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema) _, err := db.Exec(stateDataSchema)
return err return err
} }
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{ s := &stateBlockStatements{
db: db, db: db,
} }
@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries, entries types.StateEntries,
) (id types.StateBlockNID, err error) { ) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)] entries = entries[:util.SortAndUnique(entries)]
nids := types.EventNIDs{} // zero slice to not store 'null' in the DB nids := make(types.EventNIDs, entries.Len())
for _, e := range entries { for i := range entries {
nids = append(nids, e.EventNID) nids[i] = entries[i].EventNID
} }
js, err := json.Marshal(nids) js, err := json.Marshal(nids)
if err != nil { if err != nil {
@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs)) results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0 i := 0
var stateBlockNID types.StateBlockNID
var result json.RawMessage
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result json.RawMessage
if err = rows.Scan(&stateBlockNID, &result); err != nil { if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err return nil, err
} }
r := []types.EventNID{} var r []types.EventNID
if err = json.Unmarshal(result, &r); err != nil { if err = json.Unmarshal(result, &r); err != nil {
return nil, fmt.Errorf("json.Unmarshal: %w", err) return nil, fmt.Errorf("json.Unmarshal: %w", err)
} }
@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
} }
return results, err return results, err
} }
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
eventTypeNIDs = make([]int64, len(s))
eventStateKeyNIDs = make([]int64, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
} }
func createStateSnapshotTable(db *sql.DB) error { func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema) _, err := db.Exec(stateSnapshotSchema)
return err return err
} }
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{ s := &stateSnapshotStatements{
db: db, db: db,
} }
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
return return
} }
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
var id int64 err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
if err != nil { if err != nil {
return 0, err return 0, err
} }
stateNID = types.StateSnapshotNID(id)
return return
} }
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs)) results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0 i := 0
var stateBlockNIDsJSON string
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
result := &results[i] result := &results[i]
var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err return nil, err
} }

View file

@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil { if err := CreateEventsTable(db); err != nil {
return err return err
} }
if err := createRoomsTable(db); err != nil { if err := CreateRoomsTable(db); err != nil {
return err return err
} }
if err := createStateBlockTable(db); err != nil { if err := CreateStateBlockTable(db); err != nil {
return err return err
} }
if err := createStateSnapshotTable(db); err != nil { if err := CreateStateSnapshotTable(db); err != nil {
return err return err
} }
if err := CreatePrevEventsTable(db); err != nil { if err := CreatePrevEventsTable(db); err != nil {
return err return err
} }
if err := createRoomAliasesTable(db); err != nil { if err := CreateRoomAliasesTable(db); err != nil {
return err return err
} }
if err := CreateInvitesTable(db); err != nil { if err := CreateInvitesTable(db); err != nil {
@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
rooms, err := prepareRoomsTable(db) rooms, err := PrepareRoomsTable(db)
if err != nil { if err != nil {
return err return err
} }
stateBlock, err := prepareStateBlockTable(db) stateBlock, err := PrepareStateBlockTable(db)
if err != nil { if err != nil {
return err return err
} }
stateSnapshot, err := prepareStateSnapshotTable(db) stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil { if err != nil {
return err return err
} }
@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil { if err != nil {
return err return err
} }
roomAliases, err := prepareRoomAliasesTable(db) roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil { if err != nil {
return err return err
} }

View file

@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
} }
func Test_EventsTable(t *testing.T) { func Test_EventsTable(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {

View file

@ -72,7 +72,7 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
} }

View file

@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
func TestPreviousEventsTable(t *testing.T) { func TestPreviousEventsTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePreviousEventsTable(t, dbType) tab, close := mustCreatePreviousEventsTable(t, dbType)

View file

@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
func TestPublishedTable(t *testing.T) { func TestPublishedTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePublishedTable(t, dbType) tab, close := mustCreatePublishedTable(t, dbType)

View file

@ -0,0 +1,96 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateRoomAliasesTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareRoomAliasesTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateRoomAliasesTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareRoomAliasesTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestRoomAliasesTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateRoomAliasesTable(t, dbType)
defer close()
alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost"
// insert aliases
err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID)
assert.NoError(t, err)
err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID)
assert.NoError(t, err)
err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID)
assert.NoError(t, err)
// verify we can get the roomID for the alias
roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias)
assert.NoError(t, err)
assert.Equal(t, room.ID, roomID)
// .. and the creator
creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias)
assert.NoError(t, err)
assert.Equal(t, alice.ID, creator)
creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
assert.Equal(t, "", creator)
roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
assert.Equal(t, "", roomID)
// get all aliases for a room
aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, []string{alias, alias2}, aliases)
// delete an alias and verify it's deleted
err = tab.DeleteRoomAlias(ctx, nil, alias2)
assert.NoError(t, err)
aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, []string{alias}, aliases)
// deleting the same alias should be a no-op
err = tab.DeleteRoomAlias(ctx, nil, alias2)
assert.NoError(t, err)
// Delete non-existent alias should be a no-op
err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost")
assert.NoError(t, err)
})
}

View file

@ -0,0 +1,128 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateRoomsTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareRoomsTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateRoomsTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareRoomsTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestRoomsTable(t *testing.T) {
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateRoomsTable(t, dbType)
defer close()
wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version)
assert.NoError(t, err)
// Create dummy room
_, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version)
assert.NoError(t, err)
gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, wantRoomNID, gotRoomNID)
// Ensure non existent roomNID errors
roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost")
assert.Error(t, err)
assert.Equal(t, types.RoomNID(0), roomNID)
roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, &types.RoomInfo{
RoomNID: wantRoomNID,
RoomVersion: room.Version,
StateSnapshotNID: 0,
IsStub: true, // there are no latestEventNIDs
}, roomInfo)
roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
assert.NoError(t, err)
assert.Nil(t, roomInfo)
// There are no rooms with latestEventNIDs yet
roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
assert.NoError(t, err)
assert.Equal(t, 0, len(roomIDs))
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
assert.NoError(t, err)
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
// Room does not exist
_, ok := roomVersions[1337]
assert.False(t, ok)
roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
assert.NoError(t, err)
assert.Equal(t, []string{room.ID}, roomIDs)
roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"})
assert.NoError(t, err)
assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs)
wantEventNIDs := []types.EventNID{1, 2, 3}
lastEventSentNID := types.EventNID(3)
stateSnapshotNID := types.StateSnapshotNID(1)
// make the room "usable"
err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID)
assert.NoError(t, err)
roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
assert.Equal(t, &types.RoomInfo{
RoomNID: wantRoomNID,
RoomVersion: room.Version,
StateSnapshotNID: 1,
IsStub: false,
}, roomInfo)
eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
assert.NoError(t, err)
assert.Equal(t, wantEventNIDs, eventNIDs)
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
// Again, doesn't exist
_, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337)
assert.Error(t, err)
eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID)
assert.NoError(t, err)
assert.Equal(t, wantEventNIDs, eventNIDs)
assert.Equal(t, types.EventNID(3), eventNID)
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
})
}

View file

@ -0,0 +1,92 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateStateBlockTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateBlockTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateStateBlockTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareStateBlockTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestStateBlockTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateStateBlockTable(t, dbType)
defer close()
// generate some dummy data
var entries types.StateEntries
for i := 0; i < 100; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries = append(entries, entry)
}
stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(1), stateBlockNID)
// generate a different hash, to get a new StateBlockNID
var entries2 types.StateEntries
for i := 100; i < 300; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries2 = append(entries2, entry)
}
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(2), stateBlockNID)
eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2})
assert.NoError(t, err)
assert.Equal(t, len(entries), len(eventNIDs[0]))
assert.Equal(t, len(entries2), len(eventNIDs[1]))
// try to get a StateBlockNID which does not exist
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5})
assert.Error(t, err)
// This should return an error, since we can only retrieve 1 StateBlock
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5})
assert.Error(t, err)
for i := 0; i < 65555; i++ {
entry := types.StateEntry{
EventNID: types.EventNID(i),
}
entries2 = append(entries2, entry)
}
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
assert.NoError(t, err)
assert.Equal(t, types.StateBlockNID(3), stateBlockNID)
})
}

View file

@ -0,0 +1,86 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateSnapshotTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareStateSnapshotTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestStateSnapshotTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateStateSnapshotTable(t, dbType)
defer close()
// generate some dummy data
var stateBlockNIDs types.StateBlockNIDs
for i := 0; i < 100; i++ {
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
}
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// verify ON CONFLICT; Note: this updates the sequence!
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// create a second snapshot
var stateBlockNIDs2 types.StateBlockNIDs
for i := 100; i < 150; i++ {
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
}
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
assert.NoError(t, err)
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
assert.NoError(t, err)
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
// check we get an error if the state snapshot does not exist
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
assert.Error(t, err)
// create a second snapshot
for i := 0; i < 65555; i++ {
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
}
_, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
assert.NoError(t, err)
})
}

View file

@ -21,6 +21,7 @@ import (
"strings" "strings"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
) )
@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
return a.EventStateKeyNID < b.EventStateKeyNID return a.EventStateKeyNID < b.EventStateKeyNID
} }
type StateKeyTupleSorter []StateKeyTuple
func (s StateKeyTupleSorter) Len() int { return len(s) }
func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
eventTypeNIDs = make([]int64, len(s))
eventStateKeyNIDs = make([]int64, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// A StateEntry is an entry in the room state of a matrix room. // A StateEntry is an entry in the room state of a matrix room.
type StateEntry struct { type StateEntry struct {
StateKeyTuple StateKeyTuple

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"sort"
"testing" "testing"
) )
@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) {
} }
} }
} }
func TestStateKeyTupleSorter(t *testing.T) {
input := StateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -138,9 +138,12 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
timestamp := msg.Header.Get("last_active_ts") timestamp := msg.Header.Get("last_active_ts")
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync")) fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
logrus.Debugf("syncAPI received presence event: %+v", msg.Header) logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
if fromSync { // do not process local presence changes; we already did this synchronously.
return true
}
ts, err := strconv.Atoi(timestamp) ts, err := strconv.Atoi(timestamp)
if err != nil { if err != nil {
return true return true
@ -151,15 +154,19 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
newMsg := msg.Header.Get("status_msg") newMsg := msg.Header.Get("status_msg")
statusMsg = &newMsg statusMsg = &newMsg
} }
// OK is already checked, so no need to do it again // already checked, so no need to check error
p, _ := types.PresenceFromString(presence) p, _ := types.PresenceFromString(presence)
pos, err := s.db.UpdatePresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
if err != nil {
return true
}
s.stream.Advance(pos)
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
s.EmitPresence(ctx, userID, p, statusMsg, ts, fromSync)
return true return true
} }
func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
if err != nil {
logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user")
return
}
s.stream.Advance(pos)
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
}

View file

@ -55,7 +55,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) { func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser() alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
@ -68,7 +68,7 @@ func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID // dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -171,7 +171,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})

View file

@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
func TestOutputRoomEventsTable(t *testing.T) { func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType) tab, db, close := newOutputRoomEventsTable(t, dbType)

View file

@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
func TestTopologyTable(t *testing.T) { func TestTopologyTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType) tab, db, close := newTopologyTable(t, dbType)

View file

@ -53,19 +53,24 @@ type RequestPool struct {
streams *streams.Streams streams *streams.Streams
Notifier *notifier.Notifier Notifier *notifier.Notifier
producer PresencePublisher producer PresencePublisher
consumer PresenceConsumer
} }
type PresencePublisher interface { type PresencePublisher interface {
SendPresence(userID string, presence types.Presence, statusMsg *string) error SendPresence(userID string, presence types.Presence, statusMsg *string) error
} }
type PresenceConsumer interface {
EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool)
}
// NewRequestPool makes a new RequestPool // NewRequestPool makes a new RequestPool
func NewRequestPool( func NewRequestPool(
db storage.Database, cfg *config.SyncAPI, db storage.Database, cfg *config.SyncAPI,
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
rsAPI roomserverAPI.SyncRoomserverAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
streams *streams.Streams, notifier *notifier.Notifier, streams *streams.Streams, notifier *notifier.Notifier,
producer PresencePublisher, enableMetrics bool, producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
) *RequestPool { ) *RequestPool {
if enableMetrics { if enableMetrics {
prometheus.MustRegister( prometheus.MustRegister(
@ -83,6 +88,7 @@ func NewRequestPool(
streams: streams, streams: streams,
Notifier: notifier, Notifier: notifier,
producer: producer, producer: producer,
consumer: consumer,
} }
go rp.cleanLastSeen() go rp.cleanLastSeen()
go rp.cleanPresence(db, time.Minute*5) go rp.cleanPresence(db, time.Minute*5)
@ -160,6 +166,13 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
logrus.WithError(err).Error("Unable to publish presence message from sync") logrus.WithError(err).Error("Unable to publish presence message from sync")
return return
} }
// now synchronously update our view of the world. It's critical we do this before calculating
// the /sync response else we may not return presence: online immediately.
rp.consumer.EmitPresence(
context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg,
int(gomatrixserverlib.AsTimestamp(time.Now())), true,
)
} }
func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) { func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) {

View file

@ -38,6 +38,12 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream
return 0, nil return 0, nil
} }
type dummyConsumer struct{}
func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
}
func TestRequestPool_updatePresence(t *testing.T) { func TestRequestPool_updatePresence(t *testing.T) {
type args struct { type args struct {
presence string presence string
@ -45,6 +51,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
sleep time.Duration sleep time.Duration
} }
publisher := &dummyPublisher{} publisher := &dummyPublisher{}
consumer := &dummyConsumer{}
syncMap := sync.Map{} syncMap := sync.Map{}
tests := []struct { tests := []struct {
@ -101,6 +108,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
rp := &RequestPool{ rp := &RequestPool{
presence: &syncMap, presence: &syncMap,
producer: publisher, producer: publisher,
consumer: consumer,
cfg: &config.SyncAPI{ cfg: &config.SyncAPI{
Matrix: &config.Global{ Matrix: &config.Global{
JetStream: config.JetStream{ JetStream: config.JetStream{

View file

@ -70,8 +70,17 @@ func AddPublicRoutes(
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
JetStream: js, JetStream: js,
} }
presenceConsumer := consumers.NewPresenceConsumer(
base.ProcessContext, cfg, js, natsClient, syncDB,
notifier, streams.PresenceStreamProvider,
userAPI,
)
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, base.EnableMetrics) requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
}
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js, JetStream: js,
@ -137,15 +146,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start receipts consumer") logrus.WithError(err).Panicf("failed to start receipts consumer")
} }
presenceConsumer := consumers.NewPresenceConsumer(
base.ProcessContext, cfg, js, natsClient, syncDB,
notifier, streams.PresenceStreamProvider,
userAPI,
)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
}
routing.Setup( routing.Setup(
base.PublicClientAPIMux, requestPool, syncDB, userAPI, base.PublicClientAPIMux, requestPool, syncDB, userAPI,
rsAPI, cfg, base.Caches, fts, rsAPI, cfg, base.Caches, fts,

View file

@ -15,9 +15,11 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/tidwall/gjson"
) )
type syncRoomserverAPI struct { type syncRoomserverAPI struct {
@ -86,7 +88,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
} }
func testSyncAccessTokens(t *testing.T, dbType test.DBType) { func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
user := test.NewUser() user := test.NewUser(t)
room := test.NewRoom(t, user) room := test.NewRoom(t, user)
alice := userapi.Device{ alice := userapi.Device{
ID: "ALICEID", ID: "ALICEID",
@ -96,14 +98,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
} }
base, close := test.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
defer close() defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events())
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
test.MustPublishMsgs(t, jsctx, msgs...) testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct { testCases := []struct {
name string name string
@ -173,7 +175,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
} }
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
user := test.NewUser() user := test.NewUser(t)
room := test.NewRoom(t, user) room := test.NewRoom(t, user)
alice := userapi.Device{ alice := userapi.Device{
ID: "ALICEID", ID: "ALICEID",
@ -183,7 +185,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
} }
base, close := test.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
defer close() defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
@ -198,7 +200,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
sinceTokens := make([]string, len(msgs)) sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs { for i, msg := range msgs {
test.MustPublishMsgs(t, jsctx, msg) testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w := httptest.NewRecorder() w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
@ -255,6 +257,60 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
} }
} }
// Test that if we hit /sync we get back presence: online, regardless of whether messages get delivered
// via NATS. Regression test for a flakey test "User sees their own presence in a sync"
func TestSyncAPIUpdatePresenceImmediately(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testSyncAPIUpdatePresenceImmediately(t, dbType)
})
}
func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
user := test.NewUser(t)
alice := userapi.Device{
ID: "ALICEID",
UserID: user.ID,
AccessToken: "ALICE_BEARER_TOKEN",
DisplayName: "Alice",
AccountType: userapi.AccountTypeUser,
}
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.Global.Presence.EnableOutbound = true
base.Cfg.Global.Presence.EnableInbound = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": alice.AccessToken,
"timeout": "0",
"set_presence": "online",
})))
if w.Code != 200 {
t.Fatalf("got HTTP %d want %d", w.Code, 200)
}
var res types.Response
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
t.Errorf("failed to decode response body: %s", err)
}
if len(res.Presence.Events) != 1 {
t.Fatalf("expected 1 presence events, got: %+v", res.Presence.Events)
}
if res.Presence.Events[0].Sender != alice.UserID {
t.Errorf("sender: got %v want %v", res.Presence.Events[0].Sender, alice.UserID)
}
if res.Presence.Events[0].Type != "m.presence" {
t.Errorf("type: got %v want %v", res.Presence.Events[0].Type, "m.presence")
}
if gjson.ParseBytes(res.Presence.Events[0].Content).Get("presence").Str != "online" {
t.Errorf("content: not online, got %v", res.Presence.Events[0].Content)
}
}
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input)) result := make([]*nats.Msg, len(input))
for i, ev := range input { for i, ev := range input {
@ -262,7 +318,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
if ev.StateKey() != nil { if ev.StateKey() != nil {
addsStateIDs = append(addsStateIDs, ev.EventID()) addsStateIDs = append(addsStateIDs, ev.EventID())
} }
result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
Type: rsapi.OutputTypeNewRoomEvent, Type: rsapi.OutputTypeNewRoomEvent,
NewRoomEvent: &rsapi.OutputNewRoomEvent{ NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev, Event: ev,

View file

@ -44,8 +44,9 @@ func fatalError(t *testing.T, format string, args ...interface{}) {
} }
func createLocalDB(t *testing.T, dbName string) { func createLocalDB(t *testing.T, dbName string) {
if !Quiet { if _, err := exec.LookPath("createdb"); err != nil {
t.Log("Note: tests require a postgres install accessible to the current user") fatalError(t, "Note: tests require a postgres install accessible to the current user")
return
} }
createDB := exec.Command("createdb", dbName) createDB := exec.Command("createdb", dbName)
if !Quiet { if !Quiet {
@ -63,6 +64,9 @@ func createRemoteDB(t *testing.T, dbName, user, connStr string) {
if err != nil { if err != nil {
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err) fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
} }
if err = db.Ping(); err != nil {
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
}
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName)) _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
if err != nil { if err != nil {
pqErr, ok := err.(*pq.Error) pqErr, ok := err.(*pq.Error)

View file

@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier {
} }
} }
func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier {
return func(e *eventMods) {
e.keyID = keyID
}
}
func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
return func(e *eventMods) {
e.privKey = pkey
}
}
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
return func(e *eventMods) {
e.origin = origin
}
}
// Reverse a list of events // Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) out := make([]*gomatrixserverlib.HeaderedEvent, len(in))

View file

@ -2,10 +2,15 @@ package test
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"path/filepath"
"sync"
"testing" "testing"
) )
@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http
} }
return req return req
} }
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if withTLS {
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Errorf("failed to make TLS key: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
s := ""
if withTLS {
s = "s"
}
return fmt.Sprintf("http%s://localhost:%d", s, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View file

@ -25,103 +25,19 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2"
) )
const ( const (
// ConfigFile is the name of the config file for a server.
ConfigFile = "dendrite.yaml"
// ServerKeyFile is the name of the file holding the matrix server private key. // ServerKeyFile is the name of the file holding the matrix server private key.
ServerKeyFile = "server_key.pem" ServerKeyFile = "server_key.pem"
// TLSCertFile is the name of the file holding the TLS certificate used for federation. // TLSCertFile is the name of the file holding the TLS certificate used for federation.
TLSCertFile = "tls_cert.pem" TLSCertFile = "tls_cert.pem"
// TLSKeyFile is the name of the file holding the TLS key used for federation. // TLSKeyFile is the name of the file holding the TLS key used for federation.
TLSKeyFile = "tls_key.pem" TLSKeyFile = "tls_key.pem"
// MediaDir is the name of the directory used to store media.
MediaDir = "media"
) )
// MakeConfig makes a config suitable for running integration tests.
// Generates new matrix and TLS keys for the server.
func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) {
var cfg config.Dendrite
cfg.Defaults(true)
port := startPort
assignAddress := func() config.HTTPAddress {
result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port))
port++
return result
}
serverKeyPath := filepath.Join(configDir, ServerKeyFile)
tlsCertPath := filepath.Join(configDir, TLSKeyFile)
tlsKeyPath := filepath.Join(configDir, TLSCertFile)
mediaBasePath := filepath.Join(configDir, MediaDir)
if err := NewMatrixKey(serverKeyPath); err != nil {
return nil, 0, err
}
if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil {
return nil, 0, err
}
cfg.Version = config.Version
cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress())
cfg.Global.PrivateKeyPath = config.Path(serverKeyPath)
cfg.MediaAPI.BasePath = config.Path(mediaBasePath)
cfg.Global.JetStream.Addresses = []string{kafkaURI}
// TODO: Use different databases for the different schemas.
// Using the same database for every schema currently works because
// the table names are globally unique. But we might not want to
// rely on that in the future.
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database)
cfg.FederationAPI.Database.ConnectionString = config.DataSource(database)
cfg.KeyServer.Database.ConnectionString = config.DataSource(database)
cfg.MediaAPI.Database.ConnectionString = config.DataSource(database)
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.FederationAPI.InternalAPI.Listen = assignAddress()
cfg.KeyServer.InternalAPI.Listen = assignAddress()
cfg.MediaAPI.InternalAPI.Listen = assignAddress()
cfg.RoomServer.InternalAPI.Listen = assignAddress()
cfg.SyncAPI.InternalAPI.Listen = assignAddress()
cfg.UserAPI.InternalAPI.Listen = assignAddress()
cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen
cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen
cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen
cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen
cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen
cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen
cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen
return &cfg, port, nil
}
// WriteConfig writes the config file to the directory.
func WriteConfig(cfg *config.Dendrite, configDir string) error {
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666)
}
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file. // NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
func NewMatrixKey(matrixKeyPath string) (err error) { func NewMatrixKey(matrixKeyPath string) (err error) {
var data [35]byte var data [35]byte

View file

@ -15,7 +15,6 @@
package test package test
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
@ -35,12 +34,6 @@ var (
PresetTrustedPrivateChat Preset = 3 PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0) roomIDCounter = int64(0)
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
) )
type Room struct { type Room struct {
@ -49,22 +42,25 @@ type Room struct {
preset Preset preset Preset
creator *User creator *User
authEvents gomatrixserverlib.AuthEvents authEvents gomatrixserverlib.AuthEvents
events []*gomatrixserverlib.HeaderedEvent currentState map[string]*gomatrixserverlib.HeaderedEvent
events []*gomatrixserverlib.HeaderedEvent
} }
// Create a new test room. Automatically creates the initial create events. // Create a new test room. Automatically creates the initial create events.
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper() t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1) counter := atomic.AddInt64(&roomIDCounter, 1)
if creator.srvName == "" {
// set defaults then let roomModifiers override t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
}
r := &Room{ r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter), ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
creator: creator, creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil), authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat, preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9, Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
} }
for _, m := range modifiers { for _, m := range modifiers {
m(t, r) m(t, r)
@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
return r return r
} }
func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference {
t.Helper()
a, err := needed.AuthEventReferences(&r.authEvents)
if err != nil {
t.Fatalf("MustGetAuthEvents: %v", err)
}
return a
}
func (r *Room) ForwardExtremities() []string {
if len(r.events) == 0 {
return nil
}
return []string{
r.events[len(r.events)-1].EventID(),
}
}
func (r *Room) insertCreateEvents(t *testing.T) { func (r *Room) insertCreateEvents(t *testing.T) {
t.Helper() t.Helper()
var joinRule gomatrixserverlib.JoinRuleContent var joinRule gomatrixserverlib.JoinRuleContent
@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) {
joinRule.JoinRule = "public" joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared" hisVis.HistoryVisibility = "shared"
} }
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID, "creator": r.creator.ID,
"room_version": r.Version, "room_version": r.Version,
@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
} }
if mod.privKey == nil { if mod.privKey == nil {
mod.privKey = testPrivateKey mod.privKey = creator.privKey
} }
if mod.keyID == "" { if mod.keyID == "" {
mod.keyID = testKeyID mod.keyID = creator.keyID
} }
if mod.originServerTS.IsZero() { if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now() mod.originServerTS = time.Now()
} }
if mod.origin == "" { if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost") mod.origin = creator.srvName
} }
var unsigned gomatrixserverlib.RawJSON var unsigned gomatrixserverlib.RawJSON
@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
// Add a new event to this room DAG. Not thread-safe. // Add a new event to this room DAG. Not thread-safe.
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
t.Helper() t.Helper()
// Add the event to the list of auth events // Add the event to the list of auth/state events
r.events = append(r.events, he) r.events = append(r.events, he)
if he.StateKey() != nil { if he.StateKey() != nil {
err := r.authEvents.AddEvent(he.Unwrap()) err := r.authEvents.AddEvent(he.Unwrap())
if err != nil { if err != nil {
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
} }
r.currentState[he.Type()+" "+*he.StateKey()] = he
} }
} }
@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
return r.events return r.events
} }
func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent {
events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState))
i := 0
for _, e := range r.currentState {
events[i] = e
i++
}
return events
}
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper() t.Helper()
he := r.CreateEvent(t, creator, eventType, content, mods...) he := r.CreateEvent(t, creator, eventType, content, mods...)

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package test package testrig
import ( import (
"errors" "errors"
@ -24,22 +24,23 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) { func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) {
var cfg config.Dendrite var cfg config.Dendrite
cfg.Defaults(false) cfg.Defaults(false)
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
switch dbType { switch dbType {
case DBTypePostgres: case test.DBTypePostgres:
cfg.Global.Defaults(true) // autogen a signing key cfg.Global.Defaults(true) // autogen a signing key
cfg.MediaAPI.Defaults(true) // autogen a media path cfg.MediaAPI.Defaults(true) // autogen a media path
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :( // the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
connStr, close := PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
cfg.Global.DatabaseOptions = config.DatabaseOptions{ cfg.Global.DatabaseOptions = config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
MaxOpenConnections: 10, MaxOpenConnections: 10,
@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()
ConnMaxLifetimeSeconds: 60, ConnMaxLifetimeSeconds: 60,
} }
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
case DBTypeSQLite: case test.DBTypeSQLite:
cfg.Defaults(true) // sets a sqlite db per component cfg.Defaults(true) // sets a sqlite db per component
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :( // the file system event with InMemory=true :(

View file

@ -1,4 +1,4 @@
package test package testrig
import ( import (
"encoding/json" "encoding/json"

View file

@ -15,22 +15,64 @@
package test package test
import ( import (
"crypto/ed25519"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"testing"
"github.com/matrix-org/gomatrixserverlib"
) )
var ( var (
userIDCounter = int64(0) userIDCounter = int64(0)
serverName = gomatrixserverlib.ServerName("test")
keyID = gomatrixserverlib.KeyID("ed25519:test")
privateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
// private keys that tests can use
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
})
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
})
) )
type User struct { type User struct {
ID string ID string
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
srvName gomatrixserverlib.ServerName
} }
func NewUser() *User { type UserOpt func(*User)
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
ID: fmt.Sprintf("@%d:localhost", counter), return func(u *User) {
u.keyID = keyID
u.privKey = privKey
u.srvName = srvName
} }
return u }
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User
for _, opt := range opts {
opt(&u)
}
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
t.Logf("NewUser: missing signing server credentials; using default.")
WithSigningServer(serverName, keyID, privateKey)(&u)
}
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
t.Logf("NewUser: created user %s", u.ID)
return &u
} }

View file

@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
} }
func Test_Devices(t *testing.T) { func Test_Devices(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
deviceID := util.RandomString(8) deviceID := util.RandomString(8)
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
} }
func Test_KeyBackup(t *testing.T) { func Test_KeyBackup(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
} }
func Test_LoginToken(t *testing.T) { func Test_LoginToken(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
} }
func Test_OpenID(t *testing.T) { func Test_OpenID(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
token := util.RandomString(24) token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
} }
func Test_Profile(t *testing.T) { func Test_Profile(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
} }
func Test_Pusher(t *testing.T) { func Test_Pusher(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
} }
func Test_ThreePID(t *testing.T) { func Test_ThreePID(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
} }
func Test_Notification(t *testing.T) { func Test_Notification(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)

View file

@ -24,7 +24,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
internalTest "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) {
t.Run("HTTP API", func(t *testing.T) { t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, userAPI) userapi.AddInternalRoutes(router, userAPI)
apiURL, cancel := internalTest.ListenAndServe(t, router, false) apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel() defer cancel()
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
if err != nil { if err != nil {