mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-06 13:43:09 -06:00
Merge branch 'main' into s7evink/expireedus
This commit is contained in:
commit
ca463f3142
|
|
@ -20,7 +20,7 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
|
|
|||
84
docs/coverage.md
Normal file
84
docs/coverage.md
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
---
|
||||
title: Coverage
|
||||
parent: Development
|
||||
permalink: /development/coverage
|
||||
---
|
||||
|
||||
To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary:
|
||||
```patch
|
||||
diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm
|
||||
index 8f0e209c..ad057e52 100644
|
||||
--- a/lib/SyTest/Homeserver/Dendrite.pm
|
||||
+++ b/lib/SyTest/Homeserver/Dendrite.pm
|
||||
@@ -337,7 +337,7 @@ sub _start_monolith
|
||||
|
||||
$output->diag( "Starting monolith server" );
|
||||
my @command = (
|
||||
- $self->{bindir} . '/dendrite-monolith-server',
|
||||
+ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL",
|
||||
'--config', $self->{paths}{config},
|
||||
'--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port,
|
||||
'--https-bind-address', $self->{bind_host} . ':' . $self->secure_port,
|
||||
diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh
|
||||
index f009332b..7ea79869 100755
|
||||
--- a/scripts/dendrite_sytest.sh
|
||||
+++ b/scripts/dendrite_sytest.sh
|
||||
@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin
|
||||
echo >&2 "--- Building dendrite from source"
|
||||
cd /src
|
||||
mkdir -p $GOBIN
|
||||
-go install -v ./cmd/dendrite-monolith-server
|
||||
+# go install -v ./cmd/dendrite-monolith-server
|
||||
+go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server
|
||||
go install -v ./cmd/generate-keys
|
||||
cd -
|
||||
```
|
||||
|
||||
Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it,
|
||||
ensure your working directory is under the Dendrite repository then run:
|
||||
```bash
|
||||
go tool cover -func=/path/to/server-0/integrationcover.log
|
||||
```
|
||||
which will produce an output like:
|
||||
```
|
||||
...
|
||||
github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0%
|
||||
github.com/matrix-org/util/json.go:90: Protect 57.1%
|
||||
github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0%
|
||||
github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0%
|
||||
github.com/matrix-org/util/json.go:151: respond 61.5%
|
||||
github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0%
|
||||
github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0%
|
||||
github.com/matrix-org/util/json.go:202: RandomString 100.0%
|
||||
github.com/matrix-org/util/json.go:210: init 100.0%
|
||||
github.com/matrix-org/util/unique.go:13: Unique 91.7%
|
||||
github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0%
|
||||
github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0%
|
||||
total: (statements) 53.7%
|
||||
```
|
||||
The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations,
|
||||
which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required
|
||||
to remove packages which will never be tested and extension MSCs:
|
||||
```bash
|
||||
# These commands are all similar but change which package paths are _removed_ from the output.
|
||||
|
||||
# For Postgres (monolith)
|
||||
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt
|
||||
|
||||
# For Postgres (polylith)
|
||||
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt
|
||||
|
||||
# For SQLite (monolith)
|
||||
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt
|
||||
|
||||
# For SQLite (polylith)
|
||||
go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt
|
||||
```
|
||||
|
||||
A total value can then be calculated using:
|
||||
```bash
|
||||
cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}'
|
||||
```
|
||||
|
||||
|
||||
We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report.
|
||||
|
|
@ -12,12 +12,16 @@ import (
|
|||
|
||||
// FederationInternalAPI is used to query information from the federation sender.
|
||||
type FederationInternalAPI interface {
|
||||
FederationClient
|
||||
gomatrixserverlib.FederatedStateClient
|
||||
KeyserverFederationAPI
|
||||
gomatrixserverlib.KeyDatabase
|
||||
ClientFederationAPI
|
||||
RoomserverFederationAPI
|
||||
|
||||
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
|
||||
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
|
||||
PerformBroadcastEDU(
|
||||
|
|
@ -60,17 +64,43 @@ type RoomserverFederationAPI interface {
|
|||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||
}
|
||||
|
||||
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender
|
||||
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
|
||||
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
||||
// this interface are of type FederationClientError
|
||||
type FederationClient interface {
|
||||
gomatrixserverlib.FederatedStateClient
|
||||
type KeyserverFederationAPI interface {
|
||||
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||
}
|
||||
|
||||
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
|
||||
type FederationClient interface {
|
||||
gomatrixserverlib.KeyClient
|
||||
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
|
||||
|
||||
// Perform operations
|
||||
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
|
||||
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
|
||||
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
|
||||
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
|
||||
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
|
||||
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
|
||||
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
|
||||
|
||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||
|
||||
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
|
||||
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
|
||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
|
||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
|
||||
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
|
||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
|
||||
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
|
||||
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
|
||||
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
|
||||
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
|
||||
}
|
||||
|
||||
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ type KeyChangeConsumer struct {
|
|||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
serverName gomatrixserverlib.ServerName
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI
|
||||
topic string
|
||||
}
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ func NewKeyChangeConsumer(
|
|||
js nats.JetStreamContext,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
) *KeyChangeConsumer {
|
||||
return &KeyChangeConsumer{
|
||||
ctx: process.Context(),
|
||||
|
|
@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
|
|||
logger.WithError(err).Error("failed to calculate joined rooms for user")
|
||||
return true
|
||||
}
|
||||
|
||||
// send this key change to all servers who share rooms with this user.
|
||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import (
|
|||
type OutputRoomEventConsumer struct {
|
||||
ctx context.Context
|
||||
cfg *config.FederationAPI
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
db storage.Database
|
||||
|
|
@ -51,7 +51,7 @@ func NewOutputRoomEventConsumer(
|
|||
js nats.JetStreamContext,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
) *OutputRoomEventConsumer {
|
||||
return &OutputRoomEventConsumer{
|
||||
ctx: process.Context(),
|
||||
|
|
@ -89,15 +89,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
|
|||
switch output.Type {
|
||||
case api.OutputTypeNewRoomEvent:
|
||||
ev := output.NewRoomEvent.Event
|
||||
|
||||
if output.NewRoomEvent.RewritesState {
|
||||
if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
|
||||
log.WithError(err).Errorf("roomserver output log: purge room state failure")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.processMessage(*output.NewRoomEvent); err != nil {
|
||||
if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil {
|
||||
// panic rather than continue with an inconsistent database
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": ev.EventID(),
|
||||
|
|
@ -145,7 +137,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
|
|||
|
||||
// processMessage updates the list of currently joined hosts in the room
|
||||
// and then sends the event to the hosts that were joined before the event.
|
||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
|
||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
|
||||
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
|
||||
|
||||
// Ask the roomserver and add in the rest of the results into the set.
|
||||
|
|
@ -164,7 +156,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
|
|||
addsStateEvents = append(addsStateEvents, eventsRes.Events...)
|
||||
}
|
||||
|
||||
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
|
||||
addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -176,10 +168,9 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
|
|||
oldJoinedHosts, err := s.db.UpdateRoom(
|
||||
s.ctx,
|
||||
ore.Event.RoomID(),
|
||||
ore.LastSentEventID,
|
||||
ore.Event.EventID(),
|
||||
addsJoinedHosts,
|
||||
ore.RemovesStateEventIDs,
|
||||
rewritesState, // if we're re-writing state, nuke all joined hosts before adding
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -238,7 +229,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
|
||||
combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -284,10 +275,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
|
||||
// JoinedHostsFromEvents turns a list of state events into a list of joined hosts.
|
||||
// This errors if one of the events was invalid.
|
||||
// It should be impossible for an invalid event to get this far in the pipeline.
|
||||
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
|
||||
func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
|
||||
var joinedHosts []types.JoinedHost
|
||||
for _, ev := range evs {
|
||||
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
|
||||
|
|
|
|||
|
|
@ -95,8 +95,8 @@ func AddPublicRoutes(
|
|||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
base *base.BaseDendrite,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
federation api.FederationClient,
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
caches *caching.Caches,
|
||||
keyRing *gomatrixserverlib.KeyRing,
|
||||
resetBlacklist bool,
|
||||
|
|
|
|||
|
|
@ -3,18 +3,250 @@ package federationapi_test
|
|||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi"
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/internal"
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
type fedRoomserverAPI struct {
|
||||
rsapi.FederationRoomserverAPI
|
||||
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
|
||||
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
||||
}
|
||||
|
||||
// PerformJoin will call this function
|
||||
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
||||
if f.inputRoomEvents == nil {
|
||||
return
|
||||
}
|
||||
f.inputRoomEvents(ctx, req, res)
|
||||
}
|
||||
|
||||
// keychange consumer calls this
|
||||
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
|
||||
if f.queryRoomsForUser == nil {
|
||||
return nil
|
||||
}
|
||||
return f.queryRoomsForUser(ctx, req, res)
|
||||
}
|
||||
|
||||
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
|
||||
type fedClient struct {
|
||||
api.FederationClient
|
||||
allowJoins []*test.Room
|
||||
keys map[gomatrixserverlib.ServerName]struct {
|
||||
key ed25519.PrivateKey
|
||||
keyID gomatrixserverlib.KeyID
|
||||
}
|
||||
t *testing.T
|
||||
sentTxn bool
|
||||
}
|
||||
|
||||
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
|
||||
fmt.Println("GetServerKeys:", matrixServer)
|
||||
var keys gomatrixserverlib.ServerKeys
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
var pkey ed25519.PrivateKey
|
||||
for srv, data := range f.keys {
|
||||
if srv == matrixServer {
|
||||
pkey = data.key
|
||||
keyID = data.keyID
|
||||
break
|
||||
}
|
||||
}
|
||||
if pkey == nil {
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
keys.ServerName = matrixServer
|
||||
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
|
||||
publicKey := pkey.Public().(ed25519.PublicKey)
|
||||
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
|
||||
keyID: {
|
||||
Key: gomatrixserverlib.Base64Bytes(publicKey),
|
||||
},
|
||||
}
|
||||
toSign, err := json.Marshal(keys.ServerKeyFields)
|
||||
if err != nil {
|
||||
return keys, err
|
||||
}
|
||||
|
||||
keys.Raw, err = gomatrixserverlib.SignJSON(
|
||||
string(matrixServer), keyID, pkey, toSign,
|
||||
)
|
||||
if err != nil {
|
||||
return keys, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
|
||||
for _, r := range f.allowJoins {
|
||||
if r.ID == roomID {
|
||||
res.RoomVersion = r.Version
|
||||
res.JoinEvent = gomatrixserverlib.EventBuilder{
|
||||
Sender: userID,
|
||||
RoomID: roomID,
|
||||
Type: "m.room.member",
|
||||
StateKey: &userID,
|
||||
Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)),
|
||||
PrevEvents: r.ForwardExtremities(),
|
||||
}
|
||||
var needed gomatrixserverlib.StateNeeded
|
||||
needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent)
|
||||
if err != nil {
|
||||
f.t.Errorf("StateNeededForEventBuilder: %v", err)
|
||||
return
|
||||
}
|
||||
res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
|
||||
for _, r := range f.allowJoins {
|
||||
if r.ID == event.RoomID() {
|
||||
r.InsertEvent(f.t, event.Headered(r.Version))
|
||||
f.t.Logf("Join event: %v", event.EventID())
|
||||
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState())
|
||||
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
|
||||
for _, edu := range t.EDUs {
|
||||
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
|
||||
f.sentTxn = true
|
||||
}
|
||||
}
|
||||
f.t.Logf("got /send")
|
||||
return
|
||||
}
|
||||
|
||||
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
|
||||
// isn't relying on the roomserver.
|
||||
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
testFederationAPIJoinThenKeyUpdate(t, dbType)
|
||||
})
|
||||
}
|
||||
|
||||
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
base.Cfg.FederationAPI.PreferDirectFetch = true
|
||||
defer close()
|
||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||
|
||||
serverA := gomatrixserverlib.ServerName("server.a")
|
||||
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
|
||||
serverAPrivKey := test.PrivateKeyA
|
||||
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
|
||||
|
||||
myServer := base.Cfg.Global.ServerName
|
||||
myServerKeyID := base.Cfg.Global.KeyID
|
||||
myServerPrivKey := base.Cfg.Global.PrivateKey
|
||||
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
|
||||
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
|
||||
room := test.NewRoom(t, creator)
|
||||
|
||||
rsapi := &fedRoomserverAPI{
|
||||
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
|
||||
if req.Asynchronous {
|
||||
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
|
||||
}
|
||||
},
|
||||
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
|
||||
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
|
||||
res.RoomIDs = []string{room.ID}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
|
||||
},
|
||||
}
|
||||
fc := &fedClient{
|
||||
allowJoins: []*test.Room{room},
|
||||
t: t,
|
||||
keys: map[gomatrixserverlib.ServerName]struct {
|
||||
key ed25519.PrivateKey
|
||||
keyID gomatrixserverlib.KeyID
|
||||
}{
|
||||
serverA: {
|
||||
key: serverAPrivKey,
|
||||
keyID: serverAKeyID,
|
||||
},
|
||||
myServer: {
|
||||
key: myServerPrivKey,
|
||||
keyID: myServerKeyID,
|
||||
},
|
||||
},
|
||||
}
|
||||
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
|
||||
|
||||
var resp api.PerformJoinResponse
|
||||
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
|
||||
RoomID: room.ID,
|
||||
UserID: joiningUser.ID,
|
||||
ServerNames: []gomatrixserverlib.ServerName{serverA},
|
||||
}, &resp)
|
||||
if resp.JoinedVia != serverA {
|
||||
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
|
||||
}
|
||||
if resp.LastError != nil {
|
||||
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
|
||||
}
|
||||
|
||||
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
|
||||
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
|
||||
// hosts table.
|
||||
key := keyapi.DeviceMessage{
|
||||
Type: keyapi.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &keyapi.DeviceKeys{
|
||||
UserID: joiningUser.ID,
|
||||
DeviceID: "MY_DEVICE",
|
||||
DisplayName: "BLARGLE",
|
||||
KeyJSON: []byte(`{}`),
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal device message: %s", err)
|
||||
}
|
||||
|
||||
msg := &nats.Msg{
|
||||
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
|
||||
Header: nats.Header{},
|
||||
Data: b,
|
||||
}
|
||||
msg.Header.Set(jetstream.UserID, key.UserID)
|
||||
|
||||
testrig.MustPublishMsgs(t, jsctx, msg)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if !fc.sentTxn {
|
||||
t.Fatalf("did not send device list update")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
|
||||
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
|
||||
func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||
|
|
@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
|||
}
|
||||
gerr, ok := err.(gomatrix.HTTPError)
|
||||
if !ok {
|
||||
t.Errorf("failed to cast response error as gomatrix.HTTPError")
|
||||
t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err)
|
||||
continue
|
||||
}
|
||||
t.Logf("Error: %+v", gerr)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ type FederationInternalAPI struct {
|
|||
db storage.Database
|
||||
cfg *config.FederationAPI
|
||||
statistics *statistics.Statistics
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
federation *gomatrixserverlib.FederationClient
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI
|
||||
federation api.FederationClient
|
||||
keyRing *gomatrixserverlib.KeyRing
|
||||
queues *queue.OutgoingQueues
|
||||
joins sync.Map // joins currently in progress
|
||||
|
|
@ -34,8 +34,8 @@ type FederationInternalAPI struct {
|
|||
|
||||
func NewFederationInternalAPI(
|
||||
db storage.Database, cfg *config.FederationAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
federation api.FederationClient,
|
||||
statistics *statistics.Statistics,
|
||||
caches *caching.Caches,
|
||||
queues *queue.OutgoingQueues,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
|
|
@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
|||
return fmt.Errorf("respSendJoin.Check: %w", err)
|
||||
}
|
||||
|
||||
// We need to immediately update our list of joined hosts for this room now as we are technically
|
||||
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
|
||||
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
|
||||
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
||||
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
||||
// The events are trusted now as we performed auth checks above.
|
||||
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
|
||||
if err != nil {
|
||||
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
||||
}
|
||||
logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts")
|
||||
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
|
||||
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
|
||||
}
|
||||
|
||||
// If we successfully performed a send_join above then the other
|
||||
// server now thinks we're a part of the room. Send the newly
|
||||
// returned state to the roomserver to update our local view.
|
||||
|
|
@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
|
|||
|
||||
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
|
||||
func federatedAuthProvider(
|
||||
ctx context.Context, federation *gomatrixserverlib.FederationClient,
|
||||
ctx context.Context, federation api.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
|
||||
) gomatrixserverlib.AuthChainProvider {
|
||||
// A list of events that we have retried, if they were not included in
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
|
|
@ -49,8 +50,8 @@ type destinationQueue struct {
|
|||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
signing *SigningInfo
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
client *gomatrixserverlib.FederationClient // federation client
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
client fedapi.FederationClient // federation client
|
||||
origin gomatrixserverlib.ServerName // origin of requests
|
||||
destination gomatrixserverlib.ServerName // destination of requests
|
||||
running atomic.Bool // is the queue worker running?
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
|
|
@ -39,9 +40,9 @@ type OutgoingQueues struct {
|
|||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
disabled bool
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
origin gomatrixserverlib.ServerName
|
||||
client *gomatrixserverlib.FederationClient
|
||||
client fedapi.FederationClient
|
||||
statistics *statistics.Statistics
|
||||
signing *SigningInfo
|
||||
queuesMutex sync.Mutex // protects the below
|
||||
|
|
@ -85,8 +86,8 @@ func NewOutgoingQueues(
|
|||
process *process.ProcessContext,
|
||||
disabled bool,
|
||||
origin gomatrixserverlib.ServerName,
|
||||
client *gomatrixserverlib.FederationClient,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
client fedapi.FederationClient,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
statistics *statistics.Statistics,
|
||||
signing *SigningInfo,
|
||||
) *OutgoingQueues {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
// RoomAliasToID converts the queried alias into a room ID and returns it
|
||||
func RoomAliasToID(
|
||||
httpReq *http.Request,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
federation federationAPI.FederationClient,
|
||||
cfg *config.FederationAPI,
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
senderAPI federationAPI.FederationInternalAPI,
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ func Setup(
|
|||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
fsAPI *fedInternal.FederationInternalAPI,
|
||||
keys gomatrixserverlib.JSONVerifier,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
federation federationAPI.FederationClient,
|
||||
userAPI userapi.FederationUserAPI,
|
||||
keyAPI keyserverAPI.FederationKeyAPI,
|
||||
mscCfg *config.MSCs,
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ func Send(
|
|||
rsAPI api.FederationRoomserverAPI,
|
||||
keyAPI keyapi.FederationKeyAPI,
|
||||
keys gomatrixserverlib.JSONVerifier,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
federation federationAPI.FederationClient,
|
||||
mu *internal.MutexByRoom,
|
||||
servers federationAPI.ServersInRoomProvider,
|
||||
producer *producers.SyncAPIProducer,
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -57,7 +58,7 @@ var (
|
|||
func CreateInvitesFrom3PIDInvites(
|
||||
req *http.Request, rsAPI api.FederationRoomserverAPI,
|
||||
cfg *config.FederationAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
federation federationAPI.FederationClient,
|
||||
userAPI userapi.FederationUserAPI,
|
||||
) util.JSONResponse {
|
||||
var body invites
|
||||
|
|
@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite(
|
|||
roomID string,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
cfg *config.FederationAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
federation federationAPI.FederationClient,
|
||||
) util.JSONResponse {
|
||||
var builder gomatrixserverlib.EventBuilder
|
||||
if err := json.Unmarshal(request.Content(), &builder); err != nil {
|
||||
|
|
@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite(
|
|||
|
||||
// Ask the requesting server to sign the newly created event so we know it
|
||||
// acknowledged it
|
||||
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event)
|
||||
inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil)
|
||||
if err != nil {
|
||||
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq)
|
||||
if err != nil {
|
||||
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite(
|
|||
func createInviteFrom3PIDInvite(
|
||||
ctx context.Context, rsAPI api.FederationRoomserverAPI,
|
||||
cfg *config.FederationAPI,
|
||||
inv invite, federation *gomatrixserverlib.FederationClient,
|
||||
inv invite, federation federationAPI.FederationClient,
|
||||
userAPI userapi.FederationUserAPI,
|
||||
) (*gomatrixserverlib.Event, error) {
|
||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
|
||||
|
|
@ -335,7 +341,7 @@ func buildMembershipEvent(
|
|||
// them responded with an error.
|
||||
func sendToRemoteServer(
|
||||
ctx context.Context, inv invite,
|
||||
federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI,
|
||||
federation federationAPI.FederationClient, _ *config.FederationAPI,
|
||||
builder gomatrixserverlib.EventBuilder,
|
||||
) (err error) {
|
||||
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
||||
|
|
|
|||
|
|
@ -26,13 +26,12 @@ import (
|
|||
type Database interface {
|
||||
gomatrixserverlib.KeyDatabase
|
||||
|
||||
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
|
||||
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
|
||||
|
||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
|
||||
PurgeRoomState(ctx context.Context, roomID string) error
|
||||
|
||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,11 +63,21 @@ func (r *Receipt) String() string {
|
|||
// this isn't a duplicate message.
|
||||
func (d *Database) UpdateRoom(
|
||||
ctx context.Context,
|
||||
roomID, oldEventID, newEventID string,
|
||||
roomID string,
|
||||
addHosts []types.JoinedHost,
|
||||
removeHosts []string,
|
||||
purgeRoomFirst bool,
|
||||
) (joinedHosts []types.JoinedHost, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if purgeRoomFirst {
|
||||
// If the event is a create event then we'll delete all of the existing
|
||||
// data for the room. The only reason that a create event would be replayed
|
||||
// to us in this way is if we're about to receive the entire room state.
|
||||
if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
||||
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -138,20 +148,6 @@ func (d *Database) StoreJSON(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) PurgeRoomState(
|
||||
ctx context.Context, roomID string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// If the event is a create event then we'll delete all of the existing
|
||||
// data for the room. The only reason that a create event would be replayed
|
||||
// to us in this way is if we're about to receive the entire room state.
|
||||
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
||||
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
)
|
||||
|
||||
func TestEDUCache(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -1,158 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Request contains the information necessary to issue a request and test its result
|
||||
type Request struct {
|
||||
Req *http.Request
|
||||
WantedBody string
|
||||
WantedStatusCode int
|
||||
LastErr *LastRequestErr
|
||||
}
|
||||
|
||||
// LastRequestErr is a synchronised error wrapper
|
||||
// Useful for obtaining the last error from a set of requests
|
||||
type LastRequestErr struct {
|
||||
sync.Mutex
|
||||
Err error
|
||||
}
|
||||
|
||||
// Set sets the error
|
||||
func (r *LastRequestErr) Set(err error) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.Err = err
|
||||
}
|
||||
|
||||
// Get gets the error
|
||||
func (r *LastRequestErr) Get() error {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return r.Err
|
||||
}
|
||||
|
||||
// CanonicalJSONInput canonicalises a slice of JSON strings
|
||||
// Useful for test input
|
||||
func CanonicalJSONInput(jsonData []string) []string {
|
||||
for i := range jsonData {
|
||||
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
|
||||
if err != nil && err != io.EOF {
|
||||
panic(err)
|
||||
}
|
||||
jsonData[i] = string(jsonBytes)
|
||||
}
|
||||
return jsonData
|
||||
}
|
||||
|
||||
// Do issues a request and checks the status code and body of the response
|
||||
func (r *Request) Do() (err error) {
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
res, err := client.Do(r.Req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer (func() { err = res.Body.Close() })()
|
||||
|
||||
if res.StatusCode != r.WantedStatusCode {
|
||||
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
|
||||
}
|
||||
|
||||
if r.WantedBody != "" {
|
||||
resBytes, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(jsonBytes) != r.WantedBody {
|
||||
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
|
||||
// It then closes the given channel and returns.
|
||||
func (r *Request) DoUntilSuccess(done chan error) {
|
||||
r.LastErr = &LastRequestErr{}
|
||||
for {
|
||||
if err := r.Do(); err != nil {
|
||||
r.LastErr.Set(err)
|
||||
time.Sleep(1 * time.Second) // don't tightloop
|
||||
continue
|
||||
}
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Run repeatedly issues a request until success, error or a timeout is reached
|
||||
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
|
||||
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
|
||||
done := make(chan error, 1)
|
||||
|
||||
// We need to wait for the server to:
|
||||
// - have connected to the database
|
||||
// - have created the tables
|
||||
// - be listening on the given port
|
||||
go r.DoUntilSuccess(done)
|
||||
|
||||
// wait for one of:
|
||||
// - the test to pass (done channel is closed)
|
||||
// - the server to exit with an error (error sent on serverCmdChan)
|
||||
// - our test timeout to expire
|
||||
// We don't need to clean up since the main() function handles that in the event we panic
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
|
||||
if reqErr := r.LastErr.Get(); reqErr != nil {
|
||||
fmt.Println("Last /sync request error:")
|
||||
fmt.Println(reqErr)
|
||||
}
|
||||
panic(fmt.Sprintf("%v server timed out", label))
|
||||
case err := <-serverCmdChan:
|
||||
if err != nil {
|
||||
fmt.Println("=============================================================================================")
|
||||
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
|
||||
fmt.Println(" export PGHOST=/var/run/postgresql")
|
||||
fmt.Println("=============================================================================================")
|
||||
panic(err)
|
||||
}
|
||||
case <-done:
|
||||
fmt.Printf("==TESTING== %v PASSED\n", label)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KafkaExecutor executes kafka scripts.
|
||||
type KafkaExecutor struct {
|
||||
// The location of Zookeeper. Typically this is `localhost:2181`.
|
||||
ZookeeperURI string
|
||||
// The directory where Kafka is installed to. Used to locate kafka scripts.
|
||||
KafkaDirectory string
|
||||
// The location of the Kafka logs. Typically this is `localhost:9092`.
|
||||
KafkaURI string
|
||||
// Where stdout and stderr should be written to. Typically this is `os.Stderr`.
|
||||
OutputWriter io.Writer
|
||||
}
|
||||
|
||||
// CreateTopic creates a new kafka topic. This is created with a single partition.
|
||||
func (e *KafkaExecutor) CreateTopic(topic string) error {
|
||||
cmd := exec.Command(
|
||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
|
||||
"--create",
|
||||
"--zookeeper", e.ZookeeperURI,
|
||||
"--replication-factor", "1",
|
||||
"--partitions", "1",
|
||||
"--topic", topic,
|
||||
)
|
||||
cmd.Stdout = e.OutputWriter
|
||||
cmd.Stderr = e.OutputWriter
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// WriteToTopic writes data to a kafka topic.
|
||||
func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error {
|
||||
cmd := exec.Command(
|
||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"),
|
||||
"--broker-list", e.KafkaURI,
|
||||
"--topic", topic,
|
||||
)
|
||||
cmd.Stdout = e.OutputWriter
|
||||
cmd.Stderr = e.OutputWriter
|
||||
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// DeleteTopic deletes a given kafka topic if it exists.
|
||||
func (e *KafkaExecutor) DeleteTopic(topic string) error {
|
||||
cmd := exec.Command(
|
||||
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
|
||||
"--delete",
|
||||
"--if-exists",
|
||||
"--zookeeper", e.ZookeeperURI,
|
||||
"--topic", topic,
|
||||
)
|
||||
cmd.Stderr = e.OutputWriter
|
||||
cmd.Stdout = e.OutputWriter
|
||||
return cmd.Run()
|
||||
}
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// Defaulting allows assignment of string variables with a fallback default value
|
||||
// Useful for use with os.Getenv() for example
|
||||
func Defaulting(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
value = defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// CreateDatabase creates a new database, dropping it first if it exists
|
||||
func CreateDatabase(command string, args []string, database string) error {
|
||||
cmd := exec.Command(command, args...)
|
||||
cmd.Stdin = strings.NewReader(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
|
||||
)
|
||||
// Send stdout and stderr to our stderr so that we see error messages from
|
||||
// the psql process
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// CreateBackgroundCommand creates an executable command
|
||||
// The Cmd being executed is returned. A channel is also returned,
|
||||
// which will have any termination errors sent down it, followed immediately by the channel being closed.
|
||||
func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) {
|
||||
cmd := exec.Command(command, args...)
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdout = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
panic("failed to start server: " + err.Error())
|
||||
}
|
||||
cmdChan := make(chan error, 1)
|
||||
go func() {
|
||||
cmdChan <- cmd.Wait()
|
||||
close(cmdChan)
|
||||
}()
|
||||
return cmd, cmdChan
|
||||
}
|
||||
|
||||
// InitDatabase creates the database and config file needed for the server to run
|
||||
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
|
||||
if len(databases) > 0 {
|
||||
var dbCmd string
|
||||
var dbArgs []string
|
||||
if postgresContainerName == "" {
|
||||
dbCmd = "psql"
|
||||
dbArgs = []string{postgresDatabase}
|
||||
} else {
|
||||
dbCmd = "docker"
|
||||
dbArgs = []string{
|
||||
"exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase,
|
||||
}
|
||||
}
|
||||
for _, database := range databases {
|
||||
if err := CreateDatabase(dbCmd, dbArgs, database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StartProxy creates a reverse proxy
|
||||
func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) {
|
||||
proxyArgs := []string{
|
||||
"--bind-address", bindAddr,
|
||||
"--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect),
|
||||
"--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect),
|
||||
"--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect),
|
||||
"--tls-cert", "server.crt",
|
||||
"--tls-key", "server.key",
|
||||
}
|
||||
return CreateBackgroundCommand(
|
||||
filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"),
|
||||
proxyArgs,
|
||||
)
|
||||
}
|
||||
|
||||
// ListenAndServe will listen on a random high-numbered port and attach the given router.
|
||||
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
|
||||
func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %s", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
srv := http.Server{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.Handler = router
|
||||
var err error
|
||||
if useTLS {
|
||||
certFile := filepath.Join(os.TempDir(), "dendrite.cert")
|
||||
keyFile := filepath.Join(os.TempDir(), "dendrite.key")
|
||||
err = NewTLSKey(keyFile, certFile)
|
||||
if err != nil {
|
||||
t.Logf("failed to generate tls key/cert: %s", err)
|
||||
return
|
||||
}
|
||||
err = srv.ServeTLS(listener, certFile, keyFile)
|
||||
} else {
|
||||
err = srv.Serve(listener)
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Listen failed: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
secure := ""
|
||||
if useTLS {
|
||||
secure = "s"
|
||||
}
|
||||
return fmt.Sprintf("http%s://localhost:%d", secure, port), func() {
|
||||
_ = srv.Shutdown(context.Background())
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
|
@ -84,7 +84,7 @@ type DeviceListUpdater struct {
|
|||
db DeviceListUpdaterDatabase
|
||||
api DeviceListUpdaterAPI
|
||||
producer KeyChangeProducer
|
||||
fedClient fedsenderapi.FederationClient
|
||||
fedClient fedsenderapi.KeyserverFederationAPI
|
||||
workerChans []chan gomatrixserverlib.ServerName
|
||||
|
||||
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
||||
|
|
@ -127,7 +127,7 @@ type KeyChangeProducer interface {
|
|||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||
func NewDeviceListUpdater(
|
||||
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||
fedClient fedsenderapi.FederationClient, numWorkers int,
|
||||
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
userIDToMutex: make(map[string]*sync.Mutex),
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ import (
|
|||
type KeyInternalAPI struct {
|
||||
DB storage.Database
|
||||
ThisServer gomatrixserverlib.ServerName
|
||||
FedClient fedsenderapi.FederationClient
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
UserAPI userapi.KeyserverUserAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
|||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
|
||||
) api.KeyInternalAPI {
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
|
|
|
|||
|
|
@ -183,6 +183,7 @@ type FederationRoomserverAPI interface {
|
|||
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
|
||||
// Query whether a server is allowed to see an event
|
||||
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
|
||||
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
|
||||
// Query a given amount (or less) of events prior to a given set of events.
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
|
@ -22,7 +22,7 @@ var jc *nats.Conn
|
|||
|
||||
func TestMain(m *testing.M) {
|
||||
var b *base.BaseDendrite
|
||||
b, js, jc = test.Base(nil)
|
||||
b, js, jc = testrig.Base(nil)
|
||||
code := m.Run()
|
||||
b.ShutdownDendrite()
|
||||
b.WaitForComponentsToFinish()
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
tuples := types.StateKeyTupleSorter(stateKeyTuples)
|
||||
sort.Sort(tuples)
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,12 +61,12 @@ type roomAliasesStatements struct {
|
|||
deleteRoomAliasStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createRoomAliasesTable(db *sql.DB) error {
|
||||
func CreateRoomAliasesTable(db *sql.DB) error {
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
|
|
@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||
|
||||
var aliases []string
|
||||
for rows.Next() {
|
||||
var alias string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&alias); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,12 +95,12 @@ type roomStatements struct {
|
|||
bulkSelectRoomNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createRoomsTable(db *sql.DB) error {
|
||||
func CreateRoomsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(roomsSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
s := &roomStatements{}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
|
|
@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
var roomVersion gomatrixserverlib.RoomVersion
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
|
|
@ -71,12 +70,12 @@ type stateBlockStatements struct {
|
|||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateBlockTable(db *sql.DB) error {
|
||||
func CreateStateBlockTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
|
|
@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
entries types.StateEntries,
|
||||
) (id types.StateBlockNID, err error) {
|
||||
entries = entries[:util.SortAndUnique(entries)]
|
||||
var nids types.EventNIDs
|
||||
for _, e := range entries {
|
||||
nids = append(nids, e.EventNID)
|
||||
nids := make(types.EventNIDs, entries.Len())
|
||||
for i := range entries {
|
||||
nids[i] = entries[i].EventNID
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
|
|
@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
|
||||
results := make([][]types.EventNID, len(stateBlockNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
var stateBlockNID types.StateBlockNID
|
||||
var result pq.Int64Array
|
||||
for ; rows.Next(); i++ {
|
||||
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := []types.EventNID{}
|
||||
for _, e := range result {
|
||||
r = append(r, types.EventNID(e))
|
||||
r := make([]types.EventNID, len(result))
|
||||
for x := range result {
|
||||
r[x] = types.EventNID(result[x])
|
||||
}
|
||||
results[i] = r
|
||||
}
|
||||
|
|
@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
|||
}
|
||||
return pq.Int64Array(nids)
|
||||
}
|
||||
|
||||
type stateKeyTupleSorter []types.StateKeyTuple
|
||||
|
||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||
return i < len(s) && s[i] == value
|
||||
}
|
||||
|
||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||
// Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
|
||||
eventTypeNIDs = make(pq.Int64Array, len(s))
|
||||
eventStateKeyNIDs = make(pq.Int64Array, len(s))
|
||||
for i := range s {
|
||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||
}
|
||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||
return
|
||||
}
|
||||
|
||||
type int64Sorter []int64
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := stateKeyTupleSorter{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []types.StateKeyTuple{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []types.StateKeyTuple{
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
||||
// Sort the input and check it's in the right order.
|
||||
sort.Sort(input)
|
||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
||||
|
||||
for i := range want {
|
||||
if input[i] != want[i] {
|
||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||
}
|
||||
|
||||
if !input.contains(want[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range doNotWant {
|
||||
if input.contains(doNotWant[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
|
||||
for i := range wantTypeNIDs {
|
||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
|
||||
for i := range wantStateKeyNIDs {
|
||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
|
|||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateSnapshotTable(db *sql.DB) error {
|
||||
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
|
|
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
nids = nids[:util.SortAndUnique(nids)]
|
||||
var id int64
|
||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
|
||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(id)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
defer rows.Close() // nolint: errcheck
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
var stateBlockNIDs pq.Int64Array
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDs pq.Int64Array
|
||||
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateEventsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createRoomsTable(db); err != nil {
|
||||
if err := CreateRoomsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateBlockTable(db); err != nil {
|
||||
if err := CreateStateBlockTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateSnapshotTable(db); err != nil {
|
||||
if err := CreateStateSnapshotTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreatePrevEventsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createRoomAliasesTable(db); err != nil {
|
||||
if err := CreateRoomAliasesTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreateInvitesTable(db); err != nil {
|
||||
|
|
@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rooms, err := prepareRoomsTable(db)
|
||||
rooms, err := PrepareRoomsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateBlock, err := prepareStateBlockTable(db)
|
||||
stateBlock, err := PrepareStateBlockTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
||||
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomAliases, err := prepareRoomAliasesTable(db)
|
||||
roomAliases, err := PrepareRoomAliasesTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
|||
|
||||
// GetKnownRooms returns a list of all rooms we know about.
|
||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||
return d.RoomsTable.SelectRoomIDs(ctx, nil)
|
||||
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
||||
}
|
||||
|
||||
// ForgetRoom sets a users room to forgotten
|
||||
|
|
|
|||
|
|
@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
tuples := types.StateKeyTupleSorter(stateKeyTuples)
|
||||
sort.Sort(tuples)
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
|
||||
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||
for _, v := range eventNIDs {
|
||||
|
|
|
|||
|
|
@ -63,12 +63,12 @@ type roomAliasesStatements struct {
|
|||
deleteRoomAliasStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createRoomAliasesTable(db *sql.DB) error {
|
||||
func CreateRoomAliasesTable(db *sql.DB) error {
|
||||
_, err := db.Exec(roomAliasesSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
|
||||
s := &roomAliasesStatements{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var alias string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&alias); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,12 +86,12 @@ type roomStatements struct {
|
|||
selectRoomIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createRoomsTable(db *sql.DB) error {
|
||||
func CreateRoomsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(roomsSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
s := &roomStatements{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
|
||||
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
var roomVersion gomatrixserverlib.RoomVersion
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
|
|
@ -64,12 +63,12 @@ type stateBlockStatements struct {
|
|||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateBlockTable(db *sql.DB) error {
|
||||
func CreateStateBlockTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
entries types.StateEntries,
|
||||
) (id types.StateBlockNID, err error) {
|
||||
entries = entries[:util.SortAndUnique(entries)]
|
||||
nids := types.EventNIDs{} // zero slice to not store 'null' in the DB
|
||||
for _, e := range entries {
|
||||
nids = append(nids, e.EventNID)
|
||||
nids := make(types.EventNIDs, entries.Len())
|
||||
for i := range entries {
|
||||
nids[i] = entries[i].EventNID
|
||||
}
|
||||
js, err := json.Marshal(nids)
|
||||
if err != nil {
|
||||
|
|
@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
|
||||
results := make([][]types.EventNID, len(stateBlockNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
var stateBlockNID types.StateBlockNID
|
||||
var result json.RawMessage
|
||||
for ; rows.Next(); i++ {
|
||||
if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := []types.EventNID{}
|
||||
var r []types.EventNID
|
||||
if err = json.Unmarshal(result, &r); err != nil {
|
||||
return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
|
|
@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
type stateKeyTupleSorter []types.StateKeyTuple
|
||||
|
||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||
return i < len(s) && s[i] == value
|
||||
}
|
||||
|
||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||
// Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
||||
eventTypeNIDs = make([]int64, len(s))
|
||||
eventStateKeyNIDs = make([]int64, len(s))
|
||||
for i := range s {
|
||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||
}
|
||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||
return
|
||||
}
|
||||
|
||||
type int64Sorter []int64
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := stateKeyTupleSorter{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []types.StateKeyTuple{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []types.StateKeyTuple{
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
||||
// Sort the input and check it's in the right order.
|
||||
sort.Sort(input)
|
||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
||||
|
||||
for i := range want {
|
||||
if input[i] != want[i] {
|
||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||
}
|
||||
|
||||
if !input.contains(want[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range doNotWant {
|
||||
if input.contains(doNotWant[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
|
||||
for i := range wantTypeNIDs {
|
||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
|
||||
for i := range wantStateKeyNIDs {
|
||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
|
|||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateSnapshotTable(db *sql.DB) error {
|
||||
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
return
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
var id int64
|
||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
|
||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(id)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
var stateBlockNIDsJSON string
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDsJSON string
|
||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateEventsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createRoomsTable(db); err != nil {
|
||||
if err := CreateRoomsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateBlockTable(db); err != nil {
|
||||
if err := CreateStateBlockTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateSnapshotTable(db); err != nil {
|
||||
if err := CreateStateSnapshotTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreatePrevEventsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createRoomAliasesTable(db); err != nil {
|
||||
if err := CreateRoomAliasesTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := CreateInvitesTable(db); err != nil {
|
||||
|
|
@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rooms, err := prepareRoomsTable(db)
|
||||
rooms, err := PrepareRoomsTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateBlock, err := prepareStateBlockTable(db)
|
||||
stateBlock, err := PrepareStateBlockTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
||||
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomAliases, err := prepareRoomAliasesTable(db)
|
||||
roomAliases, err := PrepareRoomAliasesTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
|
|||
}
|
||||
|
||||
func Test_EventsTable(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ type Rooms interface {
|
|||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
||||
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
|
||||
SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
|
||||
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
||||
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
|
|||
|
||||
func TestPreviousEventsTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreatePreviousEventsTable(t, dbType)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
|
|||
|
||||
func TestPublishedTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreatePublishedTable(t, dbType)
|
||||
|
|
|
|||
96
roomserver/storage/tables/room_aliases_table_test.go
Normal file
96
roomserver/storage/tables/room_aliases_table_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateRoomAliasesTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareRoomAliasesTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateRoomAliasesTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareRoomAliasesTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestRoomAliasesTable(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreateRoomAliasesTable(t, dbType)
|
||||
defer close()
|
||||
alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost"
|
||||
// insert aliases
|
||||
err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// verify we can get the roomID for the alias
|
||||
roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, room.ID, roomID)
|
||||
|
||||
// .. and the creator
|
||||
creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, alice.ID, creator)
|
||||
|
||||
creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", creator)
|
||||
|
||||
roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", roomID)
|
||||
|
||||
// get all aliases for a room
|
||||
aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{alias, alias2}, aliases)
|
||||
|
||||
// delete an alias and verify it's deleted
|
||||
err = tab.DeleteRoomAlias(ctx, nil, alias2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{alias}, aliases)
|
||||
|
||||
// deleting the same alias should be a no-op
|
||||
err = tab.DeleteRoomAlias(ctx, nil, alias2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Delete non-existent alias should be a no-op
|
||||
err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
128
roomserver/storage/tables/rooms_table_test.go
Normal file
128
roomserver/storage/tables/rooms_table_test.go
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareRoomsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateRoomsTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareRoomsTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestRoomsTable(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreateRoomsTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create dummy room
|
||||
_, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version)
|
||||
assert.NoError(t, err)
|
||||
|
||||
gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantRoomNID, gotRoomNID)
|
||||
|
||||
// Ensure non existent roomNID errors
|
||||
roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, types.RoomNID(0), roomNID)
|
||||
|
||||
roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &types.RoomInfo{
|
||||
RoomNID: wantRoomNID,
|
||||
RoomVersion: room.Version,
|
||||
StateSnapshotNID: 0,
|
||||
IsStub: true, // there are no latestEventNIDs
|
||||
}, roomInfo)
|
||||
|
||||
roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, roomInfo)
|
||||
|
||||
// There are no rooms with latestEventNIDs yet
|
||||
roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(roomIDs))
|
||||
|
||||
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
|
||||
// Room does not exist
|
||||
_, ok := roomVersions[1337]
|
||||
assert.False(t, ok)
|
||||
|
||||
roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{room.ID}, roomIDs)
|
||||
|
||||
roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs)
|
||||
|
||||
wantEventNIDs := []types.EventNID{1, 2, 3}
|
||||
lastEventSentNID := types.EventNID(3)
|
||||
stateSnapshotNID := types.StateSnapshotNID(1)
|
||||
// make the room "usable"
|
||||
err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &types.RoomInfo{
|
||||
RoomNID: wantRoomNID,
|
||||
RoomVersion: room.Version,
|
||||
StateSnapshotNID: 1,
|
||||
IsStub: false,
|
||||
}, roomInfo)
|
||||
|
||||
eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantEventNIDs, eventNIDs)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
|
||||
|
||||
// Again, doesn't exist
|
||||
_, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337)
|
||||
assert.Error(t, err)
|
||||
|
||||
eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantEventNIDs, eventNIDs)
|
||||
assert.Equal(t, types.EventNID(3), eventNID)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
|
||||
})
|
||||
}
|
||||
92
roomserver/storage/tables/state_block_table_test.go
Normal file
92
roomserver/storage/tables/state_block_table_test.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateStateBlockTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareStateBlockTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateStateBlockTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareStateBlockTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStateBlockTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreateStateBlockTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
// generate some dummy data
|
||||
var entries types.StateEntries
|
||||
for i := 0; i < 100; i++ {
|
||||
entry := types.StateEntry{
|
||||
EventNID: types.EventNID(i),
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateBlockNID(1), stateBlockNID)
|
||||
|
||||
// generate a different hash, to get a new StateBlockNID
|
||||
var entries2 types.StateEntries
|
||||
for i := 100; i < 300; i++ {
|
||||
entry := types.StateEntry{
|
||||
EventNID: types.EventNID(i),
|
||||
}
|
||||
entries2 = append(entries2, entry)
|
||||
}
|
||||
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateBlockNID(2), stateBlockNID)
|
||||
|
||||
eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(entries), len(eventNIDs[0]))
|
||||
assert.Equal(t, len(entries2), len(eventNIDs[1]))
|
||||
|
||||
// try to get a StateBlockNID which does not exist
|
||||
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5})
|
||||
assert.Error(t, err)
|
||||
|
||||
// This should return an error, since we can only retrieve 1 StateBlock
|
||||
_, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5})
|
||||
assert.Error(t, err)
|
||||
|
||||
for i := 0; i < 65555; i++ {
|
||||
entry := types.StateEntry{
|
||||
EventNID: types.EventNID(i),
|
||||
}
|
||||
entries2 = append(entries2, entry)
|
||||
}
|
||||
stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateBlockNID(3), stateBlockNID)
|
||||
})
|
||||
}
|
||||
86
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
86
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateStateSnapshotTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareStateSnapshotTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateStateSnapshotTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareStateSnapshotTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStateSnapshotTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreateStateSnapshotTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
// generate some dummy data
|
||||
var stateBlockNIDs types.StateBlockNIDs
|
||||
for i := 0; i < 100; i++ {
|
||||
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
|
||||
}
|
||||
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||
|
||||
// verify ON CONFLICT; Note: this updates the sequence!
|
||||
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||
|
||||
// create a second snapshot
|
||||
var stateBlockNIDs2 types.StateBlockNIDs
|
||||
for i := 100; i < 150; i++ {
|
||||
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||
}
|
||||
|
||||
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||
assert.NoError(t, err)
|
||||
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
|
||||
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
|
||||
|
||||
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
|
||||
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
|
||||
|
||||
// check we get an error if the state snapshot does not exist
|
||||
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
|
||||
assert.Error(t, err)
|
||||
|
||||
// create a second snapshot
|
||||
for i := 0; i < 65555; i++ {
|
||||
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||
}
|
||||
_, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
|
|
@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
|
|||
return a.EventStateKeyNID < b.EventStateKeyNID
|
||||
}
|
||||
|
||||
type StateKeyTupleSorter []StateKeyTuple
|
||||
|
||||
func (s StateKeyTupleSorter) Len() int { return len(s) }
|
||||
func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||
func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool {
|
||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||
return i < len(s) && s[i] == value
|
||||
}
|
||||
|
||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||
// Assumes that the list is sorted.
|
||||
func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
||||
eventTypeNIDs = make([]int64, len(s))
|
||||
eventStateKeyNIDs = make([]int64, len(s))
|
||||
for i := range s {
|
||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||
}
|
||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||
return
|
||||
}
|
||||
|
||||
type int64Sorter []int64
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// A StateEntry is an entry in the room state of a matrix room.
|
||||
type StateEntry struct {
|
||||
StateKeyTuple
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := StateKeyTupleSorter{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []StateKeyTuple{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []StateKeyTuple{
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
||||
// Sort the input and check it's in the right order.
|
||||
sort.Sort(input)
|
||||
gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays()
|
||||
|
||||
for i := range want {
|
||||
if input[i] != want[i] {
|
||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||
}
|
||||
|
||||
if !input.contains(want[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range doNotWant {
|
||||
if input.contains(doNotWant[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
|
||||
for i := range wantTypeNIDs {
|
||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
|
||||
for i := range wantStateKeyNIDs {
|
||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,9 +138,12 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
|||
presence := msg.Header.Get("presence")
|
||||
timestamp := msg.Header.Get("last_active_ts")
|
||||
fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync"))
|
||||
|
||||
logrus.Debugf("syncAPI received presence event: %+v", msg.Header)
|
||||
|
||||
if fromSync { // do not process local presence changes; we already did this synchronously.
|
||||
return true
|
||||
}
|
||||
|
||||
ts, err := strconv.Atoi(timestamp)
|
||||
if err != nil {
|
||||
return true
|
||||
|
|
@ -151,15 +154,19 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
|||
newMsg := msg.Header.Get("status_msg")
|
||||
statusMsg = &newMsg
|
||||
}
|
||||
// OK is already checked, so no need to do it again
|
||||
// already checked, so no need to check error
|
||||
p, _ := types.PresenceFromString(presence)
|
||||
pos, err := s.db.UpdatePresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
s.stream.Advance(pos)
|
||||
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
|
||||
|
||||
s.EmitPresence(ctx, userID, p, statusMsg, ts, fromSync)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
|
||||
pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user")
|
||||
return
|
||||
}
|
||||
s.stream.Advance(pos)
|
||||
s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
|
|||
|
||||
func TestWriteEvents(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
r := test.NewRoom(t, alice)
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
|
@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
// dummy room to make sure SQL queries are filtering on room ID
|
||||
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
r := test.NewRoom(t, alice)
|
||||
for i := 0; i < 10; i++ {
|
||||
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
|
|||
|
||||
func TestOutputRoomEventsTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newOutputRoomEventsTable(t, dbType)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
|
|||
|
||||
func TestTopologyTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, db, close := newTopologyTable(t, dbType)
|
||||
|
|
|
|||
|
|
@ -53,19 +53,24 @@ type RequestPool struct {
|
|||
streams *streams.Streams
|
||||
Notifier *notifier.Notifier
|
||||
producer PresencePublisher
|
||||
consumer PresenceConsumer
|
||||
}
|
||||
|
||||
type PresencePublisher interface {
|
||||
SendPresence(userID string, presence types.Presence, statusMsg *string) error
|
||||
}
|
||||
|
||||
type PresenceConsumer interface {
|
||||
EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool)
|
||||
}
|
||||
|
||||
// NewRequestPool makes a new RequestPool
|
||||
func NewRequestPool(
|
||||
db storage.Database, cfg *config.SyncAPI,
|
||||
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
|
||||
rsAPI roomserverAPI.SyncRoomserverAPI,
|
||||
streams *streams.Streams, notifier *notifier.Notifier,
|
||||
producer PresencePublisher, enableMetrics bool,
|
||||
producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
|
||||
) *RequestPool {
|
||||
if enableMetrics {
|
||||
prometheus.MustRegister(
|
||||
|
|
@ -83,6 +88,7 @@ func NewRequestPool(
|
|||
streams: streams,
|
||||
Notifier: notifier,
|
||||
producer: producer,
|
||||
consumer: consumer,
|
||||
}
|
||||
go rp.cleanLastSeen()
|
||||
go rp.cleanPresence(db, time.Minute*5)
|
||||
|
|
@ -160,6 +166,13 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
|
|||
logrus.WithError(err).Error("Unable to publish presence message from sync")
|
||||
return
|
||||
}
|
||||
|
||||
// now synchronously update our view of the world. It's critical we do this before calculating
|
||||
// the /sync response else we may not return presence: online immediately.
|
||||
rp.consumer.EmitPresence(
|
||||
context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg,
|
||||
int(gomatrixserverlib.AsTimestamp(time.Now())), true,
|
||||
)
|
||||
}
|
||||
|
||||
func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) {
|
||||
|
|
@ -238,8 +251,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
waitingSyncRequests.Inc()
|
||||
defer waitingSyncRequests.Dec()
|
||||
|
||||
// loop until we get some data
|
||||
for {
|
||||
startTime := time.Now()
|
||||
currentPos := rp.Notifier.CurrentPosition()
|
||||
|
||||
// if the since token matches the current positions, wait via the notifier
|
||||
if !rp.shouldReturnImmediately(syncReq, currentPos) {
|
||||
timer := time.NewTimer(syncReq.Timeout) // case of timeout=0 is handled above
|
||||
defer timer.Stop()
|
||||
|
|
@ -352,12 +369,34 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
syncReq.Since.PresencePosition, currentPos.PresencePosition,
|
||||
),
|
||||
}
|
||||
// it's possible for there to be no updates for this user even though since < current pos,
|
||||
// e.g busy servers with a quiet user. In this scenario, we don't want to return a no-op
|
||||
// response immediately, so let's try this again but pretend they bumped their since token.
|
||||
// If the incremental sync was processed very quickly then we expect the next loop to block
|
||||
// with a notifier, but if things are slow it's entirely possible that currentPos is no
|
||||
// longer the current position so we will hit this code path again. We need to do this and
|
||||
// not return a no-op response because:
|
||||
// - It's an inefficient use of bandwidth.
|
||||
// - Some sytests which test 'waking up' sync rely on some sync requests to block, which
|
||||
// they weren't always doing, resulting in flakey tests.
|
||||
if !syncReq.Response.HasUpdates() {
|
||||
syncReq.Since = currentPos
|
||||
// do not loop again if the ?timeout= is 0 as that means "return immediately"
|
||||
if syncReq.Timeout > 0 {
|
||||
syncReq.Timeout = syncReq.Timeout - time.Since(startTime)
|
||||
if syncReq.Timeout < 0 {
|
||||
syncReq.Timeout = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: syncReq.Response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,12 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream
|
|||
return 0, nil
|
||||
}
|
||||
|
||||
type dummyConsumer struct{}
|
||||
|
||||
func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) {
|
||||
|
||||
}
|
||||
|
||||
func TestRequestPool_updatePresence(t *testing.T) {
|
||||
type args struct {
|
||||
presence string
|
||||
|
|
@ -45,6 +51,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
|
|||
sleep time.Duration
|
||||
}
|
||||
publisher := &dummyPublisher{}
|
||||
consumer := &dummyConsumer{}
|
||||
syncMap := sync.Map{}
|
||||
|
||||
tests := []struct {
|
||||
|
|
@ -101,6 +108,7 @@ func TestRequestPool_updatePresence(t *testing.T) {
|
|||
rp := &RequestPool{
|
||||
presence: &syncMap,
|
||||
producer: publisher,
|
||||
consumer: consumer,
|
||||
cfg: &config.SyncAPI{
|
||||
Matrix: &config.Global{
|
||||
JetStream: config.JetStream{
|
||||
|
|
|
|||
|
|
@ -64,8 +64,17 @@ func AddPublicRoutes(
|
|||
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
||||
JetStream: js,
|
||||
}
|
||||
presenceConsumer := consumers.NewPresenceConsumer(
|
||||
base.ProcessContext, cfg, js, natsClient, syncDB,
|
||||
notifier, streams.PresenceStreamProvider,
|
||||
userAPI,
|
||||
)
|
||||
|
||||
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, base.EnableMetrics)
|
||||
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
|
||||
|
||||
if err = presenceConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start presence consumer")
|
||||
}
|
||||
|
||||
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
|
||||
JetStream: js,
|
||||
|
|
@ -131,15 +140,6 @@ func AddPublicRoutes(
|
|||
logrus.WithError(err).Panicf("failed to start receipts consumer")
|
||||
}
|
||||
|
||||
presenceConsumer := consumers.NewPresenceConsumer(
|
||||
base.ProcessContext, cfg, js, natsClient, syncDB,
|
||||
notifier, streams.PresenceStreamProvider,
|
||||
userAPI,
|
||||
)
|
||||
if err = presenceConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start presence consumer")
|
||||
}
|
||||
|
||||
routing.Setup(
|
||||
base.PublicClientAPIMux, requestPool, syncDB, userAPI,
|
||||
rsAPI, cfg, base.Caches,
|
||||
|
|
|
|||
|
|
@ -15,9 +15,11 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type syncRoomserverAPI struct {
|
||||
|
|
@ -86,7 +88,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
|
|||
}
|
||||
|
||||
func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
||||
user := test.NewUser()
|
||||
user := test.NewUser(t)
|
||||
room := test.NewRoom(t, user)
|
||||
alice := userapi.Device{
|
||||
ID: "ALICEID",
|
||||
|
|
@ -96,14 +98,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
|
|||
AccountType: userapi.AccountTypeUser,
|
||||
}
|
||||
|
||||
base, close := test.CreateBaseDendrite(t, dbType)
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer close()
|
||||
|
||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||
msgs := toNATSMsgs(t, base, room.Events())
|
||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
||||
test.MustPublishMsgs(t, jsctx, msgs...)
|
||||
testrig.MustPublishMsgs(t, jsctx, msgs...)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
|
@ -173,7 +175,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
|
|||
}
|
||||
|
||||
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
||||
user := test.NewUser()
|
||||
user := test.NewUser(t)
|
||||
room := test.NewRoom(t, user)
|
||||
alice := userapi.Device{
|
||||
ID: "ALICEID",
|
||||
|
|
@ -183,7 +185,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
|||
AccountType: userapi.AccountTypeUser,
|
||||
}
|
||||
|
||||
base, close := test.CreateBaseDendrite(t, dbType)
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer close()
|
||||
|
||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
|
|
@ -198,7 +200,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
|||
sinceTokens := make([]string, len(msgs))
|
||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
|
||||
for i, msg := range msgs {
|
||||
test.MustPublishMsgs(t, jsctx, msg)
|
||||
testrig.MustPublishMsgs(t, jsctx, msg)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w := httptest.NewRecorder()
|
||||
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||
|
|
@ -255,6 +257,60 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that if we hit /sync we get back presence: online, regardless of whether messages get delivered
|
||||
// via NATS. Regression test for a flakey test "User sees their own presence in a sync"
|
||||
func TestSyncAPIUpdatePresenceImmediately(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
testSyncAPIUpdatePresenceImmediately(t, dbType)
|
||||
})
|
||||
}
|
||||
|
||||
func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
|
||||
user := test.NewUser(t)
|
||||
alice := userapi.Device{
|
||||
ID: "ALICEID",
|
||||
UserID: user.ID,
|
||||
AccessToken: "ALICE_BEARER_TOKEN",
|
||||
DisplayName: "Alice",
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
}
|
||||
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
base.Cfg.Global.Presence.EnableOutbound = true
|
||||
base.Cfg.Global.Presence.EnableInbound = true
|
||||
defer close()
|
||||
|
||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
|
||||
w := httptest.NewRecorder()
|
||||
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
|
||||
"access_token": alice.AccessToken,
|
||||
"timeout": "0",
|
||||
"set_presence": "online",
|
||||
})))
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("got HTTP %d want %d", w.Code, 200)
|
||||
}
|
||||
var res types.Response
|
||||
if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
|
||||
t.Errorf("failed to decode response body: %s", err)
|
||||
}
|
||||
if len(res.Presence.Events) != 1 {
|
||||
t.Fatalf("expected 1 presence events, got: %+v", res.Presence.Events)
|
||||
}
|
||||
if res.Presence.Events[0].Sender != alice.UserID {
|
||||
t.Errorf("sender: got %v want %v", res.Presence.Events[0].Sender, alice.UserID)
|
||||
}
|
||||
if res.Presence.Events[0].Type != "m.presence" {
|
||||
t.Errorf("type: got %v want %v", res.Presence.Events[0].Type, "m.presence")
|
||||
}
|
||||
if gjson.ParseBytes(res.Presence.Events[0].Content).Get("presence").Str != "online" {
|
||||
t.Errorf("content: not online, got %v", res.Presence.Events[0].Content)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
|
||||
result := make([]*nats.Msg, len(input))
|
||||
for i, ev := range input {
|
||||
|
|
@ -262,7 +318,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
|
|||
if ev.StateKey() != nil {
|
||||
addsStateIDs = append(addsStateIDs, ev.EventID())
|
||||
}
|
||||
result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
|
||||
result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
|
||||
Type: rsapi.OutputTypeNewRoomEvent,
|
||||
NewRoomEvent: &rsapi.OutputNewRoomEvent{
|
||||
Event: ev,
|
||||
|
|
|
|||
|
|
@ -350,6 +350,19 @@ type Response struct {
|
|||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
}
|
||||
|
||||
func (r *Response) HasUpdates() bool {
|
||||
// purposefully exclude DeviceListsOTKCount as we always include them
|
||||
return (len(r.AccountData.Events) > 0 ||
|
||||
len(r.Presence.Events) > 0 ||
|
||||
len(r.Rooms.Invite) > 0 ||
|
||||
len(r.Rooms.Join) > 0 ||
|
||||
len(r.Rooms.Leave) > 0 ||
|
||||
len(r.Rooms.Peek) > 0 ||
|
||||
len(r.ToDevice.Events) > 0 ||
|
||||
len(r.DeviceLists.Changed) > 0 ||
|
||||
len(r.DeviceLists.Left) > 0)
|
||||
}
|
||||
|
||||
// NewResponse creates an empty response with initialised maps.
|
||||
func NewResponse() *Response {
|
||||
res := Response{}
|
||||
|
|
|
|||
|
|
@ -44,8 +44,9 @@ func fatalError(t *testing.T, format string, args ...interface{}) {
|
|||
}
|
||||
|
||||
func createLocalDB(t *testing.T, dbName string) {
|
||||
if !Quiet {
|
||||
t.Log("Note: tests require a postgres install accessible to the current user")
|
||||
if _, err := exec.LookPath("createdb"); err != nil {
|
||||
fatalError(t, "Note: tests require a postgres install accessible to the current user")
|
||||
return
|
||||
}
|
||||
createDB := exec.Command("createdb", dbName)
|
||||
if !Quiet {
|
||||
|
|
@ -63,6 +64,9 @@ func createRemoteDB(t *testing.T, dbName, user, connStr string) {
|
|||
if err != nil {
|
||||
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
|
||||
}
|
||||
if err = db.Ping(); err != nil {
|
||||
fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err)
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
|
||||
if err != nil {
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier {
|
|||
}
|
||||
}
|
||||
|
||||
func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.keyID = keyID
|
||||
}
|
||||
}
|
||||
|
||||
func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.privKey = pkey
|
||||
}
|
||||
}
|
||||
|
||||
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
|
||||
return func(e *eventMods) {
|
||||
e.origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse a list of events
|
||||
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
|
||||
|
|
|
|||
47
test/http.go
47
test/http.go
|
|
@ -2,10 +2,15 @@ package test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
|
@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http
|
|||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ListenAndServe will listen on a random high-numbered port and attach the given router.
|
||||
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
|
||||
func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %s", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
srv := http.Server{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.Handler = router
|
||||
var err error
|
||||
if withTLS {
|
||||
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
|
||||
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
|
||||
err = NewTLSKey(keyFile, certFile)
|
||||
if err != nil {
|
||||
t.Errorf("failed to make TLS key: %s", err)
|
||||
return
|
||||
}
|
||||
err = srv.ServeTLS(listener, certFile, keyFile)
|
||||
} else {
|
||||
err = srv.Serve(listener)
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Listen failed: %s", err)
|
||||
}
|
||||
}()
|
||||
s := ""
|
||||
if withTLS {
|
||||
s = "s"
|
||||
}
|
||||
return fmt.Sprintf("http%s://localhost:%d", s, port), func() {
|
||||
_ = srv.Shutdown(context.Background())
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,103 +25,19 @@ import (
|
|||
"io/ioutil"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// ConfigFile is the name of the config file for a server.
|
||||
ConfigFile = "dendrite.yaml"
|
||||
// ServerKeyFile is the name of the file holding the matrix server private key.
|
||||
ServerKeyFile = "server_key.pem"
|
||||
// TLSCertFile is the name of the file holding the TLS certificate used for federation.
|
||||
TLSCertFile = "tls_cert.pem"
|
||||
// TLSKeyFile is the name of the file holding the TLS key used for federation.
|
||||
TLSKeyFile = "tls_key.pem"
|
||||
// MediaDir is the name of the directory used to store media.
|
||||
MediaDir = "media"
|
||||
)
|
||||
|
||||
// MakeConfig makes a config suitable for running integration tests.
|
||||
// Generates new matrix and TLS keys for the server.
|
||||
func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) {
|
||||
var cfg config.Dendrite
|
||||
cfg.Defaults(true)
|
||||
|
||||
port := startPort
|
||||
assignAddress := func() config.HTTPAddress {
|
||||
result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port))
|
||||
port++
|
||||
return result
|
||||
}
|
||||
|
||||
serverKeyPath := filepath.Join(configDir, ServerKeyFile)
|
||||
tlsCertPath := filepath.Join(configDir, TLSKeyFile)
|
||||
tlsKeyPath := filepath.Join(configDir, TLSCertFile)
|
||||
mediaBasePath := filepath.Join(configDir, MediaDir)
|
||||
|
||||
if err := NewMatrixKey(serverKeyPath); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
cfg.Version = config.Version
|
||||
|
||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress())
|
||||
cfg.Global.PrivateKeyPath = config.Path(serverKeyPath)
|
||||
|
||||
cfg.MediaAPI.BasePath = config.Path(mediaBasePath)
|
||||
|
||||
cfg.Global.JetStream.Addresses = []string{kafkaURI}
|
||||
|
||||
// TODO: Use different databases for the different schemas.
|
||||
// Using the same database for every schema currently works because
|
||||
// the table names are globally unique. But we might not want to
|
||||
// rely on that in the future.
|
||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
|
||||
|
||||
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
|
||||
cfg.FederationAPI.InternalAPI.Listen = assignAddress()
|
||||
cfg.KeyServer.InternalAPI.Listen = assignAddress()
|
||||
cfg.MediaAPI.InternalAPI.Listen = assignAddress()
|
||||
cfg.RoomServer.InternalAPI.Listen = assignAddress()
|
||||
cfg.SyncAPI.InternalAPI.Listen = assignAddress()
|
||||
cfg.UserAPI.InternalAPI.Listen = assignAddress()
|
||||
|
||||
cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen
|
||||
cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen
|
||||
cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen
|
||||
cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen
|
||||
cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen
|
||||
cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen
|
||||
cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen
|
||||
|
||||
return &cfg, port, nil
|
||||
}
|
||||
|
||||
// WriteConfig writes the config file to the directory.
|
||||
func WriteConfig(cfg *config.Dendrite, configDir string) error {
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666)
|
||||
}
|
||||
|
||||
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
|
||||
func NewMatrixKey(matrixKeyPath string) (err error) {
|
||||
var data [35]byte
|
||||
54
test/room.go
54
test/room.go
|
|
@ -15,7 +15,6 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
|
@ -35,12 +34,6 @@ var (
|
|||
PresetTrustedPrivateChat Preset = 3
|
||||
|
||||
roomIDCounter = int64(0)
|
||||
|
||||
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
|
||||
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
})
|
||||
)
|
||||
|
||||
type Room struct {
|
||||
|
|
@ -50,6 +43,7 @@ type Room struct {
|
|||
creator *User
|
||||
|
||||
authEvents gomatrixserverlib.AuthEvents
|
||||
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
||||
events []*gomatrixserverlib.HeaderedEvent
|
||||
}
|
||||
|
||||
|
|
@ -57,14 +51,16 @@ type Room struct {
|
|||
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
||||
t.Helper()
|
||||
counter := atomic.AddInt64(&roomIDCounter, 1)
|
||||
|
||||
// set defaults then let roomModifiers override
|
||||
if creator.srvName == "" {
|
||||
t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
|
||||
}
|
||||
r := &Room{
|
||||
ID: fmt.Sprintf("!%d:localhost", counter),
|
||||
ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
|
||||
creator: creator,
|
||||
authEvents: gomatrixserverlib.NewAuthEvents(nil),
|
||||
preset: PresetPublicChat,
|
||||
Version: gomatrixserverlib.RoomVersionV9,
|
||||
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
|
||||
}
|
||||
for _, m := range modifiers {
|
||||
m(t, r)
|
||||
|
|
@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
|
|||
return r
|
||||
}
|
||||
|
||||
func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference {
|
||||
t.Helper()
|
||||
a, err := needed.AuthEventReferences(&r.authEvents)
|
||||
if err != nil {
|
||||
t.Fatalf("MustGetAuthEvents: %v", err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (r *Room) ForwardExtremities() []string {
|
||||
if len(r.events) == 0 {
|
||||
return nil
|
||||
}
|
||||
return []string{
|
||||
r.events[len(r.events)-1].EventID(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) insertCreateEvents(t *testing.T) {
|
||||
t.Helper()
|
||||
var joinRule gomatrixserverlib.JoinRuleContent
|
||||
|
|
@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
|||
joinRule.JoinRule = "public"
|
||||
hisVis.HistoryVisibility = "shared"
|
||||
}
|
||||
|
||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
|
||||
"creator": r.creator.ID,
|
||||
"room_version": r.Version,
|
||||
|
|
@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
|
|||
}
|
||||
|
||||
if mod.privKey == nil {
|
||||
mod.privKey = testPrivateKey
|
||||
mod.privKey = creator.privKey
|
||||
}
|
||||
if mod.keyID == "" {
|
||||
mod.keyID = testKeyID
|
||||
mod.keyID = creator.keyID
|
||||
}
|
||||
if mod.originServerTS.IsZero() {
|
||||
mod.originServerTS = time.Now()
|
||||
}
|
||||
if mod.origin == "" {
|
||||
mod.origin = gomatrixserverlib.ServerName("localhost")
|
||||
mod.origin = creator.srvName
|
||||
}
|
||||
|
||||
var unsigned gomatrixserverlib.RawJSON
|
||||
|
|
@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
|
|||
// Add a new event to this room DAG. Not thread-safe.
|
||||
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
|
||||
t.Helper()
|
||||
// Add the event to the list of auth events
|
||||
// Add the event to the list of auth/state events
|
||||
r.events = append(r.events, he)
|
||||
if he.StateKey() != nil {
|
||||
err := r.authEvents.AddEvent(he.Unwrap())
|
||||
if err != nil {
|
||||
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
|
||||
}
|
||||
r.currentState[he.Type()+" "+*he.StateKey()] = he
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
|
|||
return r.events
|
||||
}
|
||||
|
||||
func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent {
|
||||
events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState))
|
||||
i := 0
|
||||
for _, e := range r.currentState {
|
||||
events[i] = e
|
||||
i++
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
|
||||
t.Helper()
|
||||
he := r.CreateEvent(t, creator, eventType, content, mods...)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package test
|
||||
package testrig
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
|
@ -24,22 +24,23 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) {
|
||||
func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) {
|
||||
var cfg config.Dendrite
|
||||
cfg.Defaults(false)
|
||||
cfg.Global.JetStream.InMemory = true
|
||||
|
||||
switch dbType {
|
||||
case DBTypePostgres:
|
||||
case test.DBTypePostgres:
|
||||
cfg.Global.Defaults(true) // autogen a signing key
|
||||
cfg.MediaAPI.Defaults(true) // autogen a media path
|
||||
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
||||
// the file system event with InMemory=true :(
|
||||
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
|
||||
connStr, close := PrepareDBConnectionString(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
cfg.Global.DatabaseOptions = config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
MaxOpenConnections: 10,
|
||||
|
|
@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()
|
|||
ConnMaxLifetimeSeconds: 60,
|
||||
}
|
||||
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
|
||||
case DBTypeSQLite:
|
||||
case test.DBTypeSQLite:
|
||||
cfg.Defaults(true) // sets a sqlite db per component
|
||||
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
|
||||
// the file system event with InMemory=true :(
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package test
|
||||
package testrig
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
52
test/user.go
52
test/user.go
|
|
@ -15,22 +15,64 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var (
|
||||
userIDCounter = int64(0)
|
||||
|
||||
serverName = gomatrixserverlib.ServerName("test")
|
||||
keyID = gomatrixserverlib.KeyID("ed25519:test")
|
||||
privateKey = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
|
||||
})
|
||||
|
||||
// private keys that tests can use
|
||||
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
|
||||
})
|
||||
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
|
||||
})
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
// key ID and private key of the server who has this user, if known.
|
||||
keyID gomatrixserverlib.KeyID
|
||||
privKey ed25519.PrivateKey
|
||||
srvName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func NewUser() *User {
|
||||
counter := atomic.AddInt64(&userIDCounter, 1)
|
||||
u := &User{
|
||||
ID: fmt.Sprintf("@%d:localhost", counter),
|
||||
type UserOpt func(*User)
|
||||
|
||||
func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
|
||||
return func(u *User) {
|
||||
u.keyID = keyID
|
||||
u.privKey = privKey
|
||||
u.srvName = srvName
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func NewUser(t *testing.T, opts ...UserOpt) *User {
|
||||
counter := atomic.AddInt64(&userIDCounter, 1)
|
||||
var u User
|
||||
for _, opt := range opts {
|
||||
opt(&u)
|
||||
}
|
||||
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
|
||||
t.Logf("NewUser: missing signing server credentials; using default.")
|
||||
WithSigningServer(serverName, keyID, privateKey)(&u)
|
||||
}
|
||||
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
|
||||
t.Logf("NewUser: created user %s", u.ID)
|
||||
return &u
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Devices(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
|
|
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_KeyBackup(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_LoginToken(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
|
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_OpenID(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
token := util.RandomString(24)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Profile(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Pusher(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_ThreePID(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
|
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Notification(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
room := test.NewRoom(t, alice)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ import (
|
|||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
internalTest "github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
|
|
@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) {
|
|||
t.Run("HTTP API", func(t *testing.T) {
|
||||
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||
userapi.AddInternalRoutes(router, userAPI)
|
||||
apiURL, cancel := internalTest.ListenAndServe(t, router, false)
|
||||
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||
defer cancel()
|
||||
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in a new issue