From cb6e9c068495945663ad1630fab4cdc679b5a589 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 10 Mar 2023 16:10:46 +0100 Subject: [PATCH] Cache eventTypeNIDs and eventStateKeyNIDs when using GetStateEvent --- roomserver/roomserver_test.go | 36 ++++++++++ roomserver/storage/shared/storage.go | 6 +- roomserver/storage/tables/interface_test.go | 75 +++++++++++++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 roomserver/storage/tables/interface_test.go diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index a3ca5909e..ab0bb188e 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/state" @@ -571,3 +572,38 @@ func TestRedaction(t *testing.T) { } }) } + +func TestNewServerACLs(t *testing.T) { + alice := test.NewUser(t) + roomWithACL := test.NewRoom(t, alice) + + roomWithACL.CreateAndInsert(t, alice, "m.room.server_acl", acls.ServerACL{ + Allowed: []string{"*"}, + Denied: []string{"localhost"}, + AllowIPLiterals: false, + }, test.WithStateKey("")) + + roomWithoutACL := test.NewRoom(t, alice) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, closeBase := mustCreateDatabase(t, dbType) + defer closeBase() + + // start JetStream listeners + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + // let the RS create the events + err := api.SendEvents(context.Background(), rsAPI, api.KindNew, roomWithACL.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + err = api.SendEvents(context.Background(), rsAPI, api.KindNew, roomWithoutACL.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + + // create new server ACLs and verify server is banned/not banned + serverACLs := acls.NewServerACLs(db) + banned := serverACLs.IsServerBannedFromRoom("localhost", roomWithACL.ID) + assert.Equal(t, true, banned) + banned = serverACLs.IsServerBannedFromRoom("localhost", roomWithoutACL.ID) + assert.Equal(t, false, banned) + }) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 78bda95e5..cad93d643 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1154,7 +1154,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if roomInfo.IsStub() { return nil, nil } - eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + eventTypeNIDMap, err := d.eventTypeNIDs(ctx, nil, []string{evType}) if err == sql.ErrNoRows { // No rooms have an event of this type, otherwise we'd have an event type NID return nil, nil @@ -1162,7 +1162,8 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } - stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + eventTypeNID := eventTypeNIDMap[evType] + stateKeyNIDMap, err := d.eventStateKeyNIDs(ctx, nil, []string{stateKey}) if err == sql.ErrNoRows { // No rooms have a state event with this state key, otherwise we'd have an state key NID return nil, nil @@ -1170,6 +1171,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } + stateKeyNID := stateKeyNIDMap[stateKey] entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface_test.go b/roomserver/storage/tables/interface_test.go new file mode 100644 index 000000000..16f957b54 --- /dev/null +++ b/roomserver/storage/tables/interface_test.go @@ -0,0 +1,75 @@ +package tables + +import ( + "testing" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestExtractContentValue(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + tests := []struct { + name string + event *gomatrixserverlib.HeaderedEvent + want string + }{ + { + name: "returns creator ID for create events", + event: room.Events()[0], + want: alice.ID, + }, + { + name: "returns the alias for canonical alias events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]string{"alias": "#test:test"}), + want: "#test:test", + }, + { + name: "returns the history_visibility for history visibility events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomHistoryVisibility, map[string]string{"history_visibility": "shared"}), + want: "shared", + }, + { + name: "returns the join rules for join_rules events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomJoinRules, map[string]string{"join_rule": "public"}), + want: "public", + }, + { + name: "returns the membership for room_member events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomMember, map[string]string{"membership": "join"}, test.WithStateKey(alice.ID)), + want: "join", + }, + { + name: "returns the room name for room_name events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomName, map[string]string{"name": "testing"}, test.WithStateKey(alice.ID)), + want: "testing", + }, + { + name: "returns the room avatar for avatar events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomAvatar, map[string]string{"url": "mxc://testing"}, test.WithStateKey(alice.ID)), + want: "mxc://testing", + }, + { + name: "returns the room topic for topic events", + event: room.CreateEvent(t, alice, gomatrixserverlib.MRoomTopic, map[string]string{"topic": "testing"}, test.WithStateKey(alice.ID)), + want: "testing", + }, + { + name: "returns guest_access for guest access events", + event: room.CreateEvent(t, alice, "m.room.guest_access", map[string]string{"guest_access": "forbidden"}, test.WithStateKey(alice.ID)), + want: "forbidden", + }, + { + name: "returns empty string if key can't be found or unknown event", + event: room.CreateEvent(t, alice, "idontexist", nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, ExtractContentValue(tt.event), "ExtractContentValue(%v)", tt.event) + }) + } +}