diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f7408fa9f..564d317bc 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -2,10 +2,12 @@ package storage_test import ( "context" + "reflect" "testing" "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/storage" @@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) { assert.Equal(t, 2, len(data)) }) } + +func TestOutboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add outbound peek + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err := db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} + +func TestInboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add inbound peek + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err := db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +}