diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 33c3eae30..f46dae502 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -35,10 +35,11 @@ func main() { fsAPI := base.FederationSenderHTTPClient() eduInputAPI := base.EDUServerClient() userAPI := base.UserAPIClient() + stateAPI := base.CurrentStateAPIClient() clientapi.AddPublicRoutes( base.PublicAPIMux, base.Cfg, base.KafkaProducer, deviceDB, accountDB, federation, - rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, + rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, ) base.SetupAndServeHTTP(string(base.Cfg.Bind.ClientAPI), string(base.Cfg.Listen.ClientAPI)) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 602f923e2..b7e86b77c 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -167,7 +167,7 @@ func main() { if err != nil { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - stateAPI := currentstateserver.NewInternalAPI(base.Cfg, base.KafkaConsumer) + stateAPI := currentstateserver.NewInternalAPI(base.Base.Cfg, base.Base.KafkaConsumer) monolith := setup.Monolith{ Config: base.Base.Cfg, diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index 95ca609b4..a0627fea7 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -136,8 +136,8 @@ func TestQueryCurrentState(t *testing.T) { }, }, wantRes: api.QueryCurrentStateResponse{ - StateEvents: map[gomatrixserverlib.StateKeyTuple]gomatrixserverlib.HeaderedEvent{ - plTuple: plEvent, + StateEvents: map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent{ + plTuple: &plEvent, }, }, }, diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index 0764b3f69..95621913b 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "strconv" "github.com/lib/pq" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" @@ -26,7 +27,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -const currentRoomStateSchema = ` +var leaveEnum = strconv.Itoa(tables.MembershipToEnum["leave"]) + +var currentRoomStateSchema = ` -- Stores the current room state for every room. CREATE TABLE IF NOT EXISTS currentstate_current_room_state ( -- The 'room_id' key for the state event. @@ -41,16 +44,16 @@ CREATE TABLE IF NOT EXISTS currentstate_current_room_state ( state_key TEXT NOT NULL, -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. headered_event_json TEXT NOT NULL, - -- The 'content.membership' value if this event is an m.room.member event. For other - -- events, this will be NULL. - membership TEXT, + -- The 'content.membership' enum value if this event is an m.room.member event. + membership SMALLINT NOT NULL DEFAULT 0, -- Clobber based on 3-uple of room_id, type and state_key CONSTRAINT currentstate_current_room_state_unique UNIQUE (room_id, type, state_key) ); -- for event deletion CREATE UNIQUE INDEX IF NOT EXISTS currentstate_event_id_idx ON currentstate_current_room_state(event_id, room_id, type, sender); -- for querying membership states of users -CREATE INDEX IF NOT EXISTS currentstate_membership_idx ON currentstate_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; +CREATE INDEX IF NOT EXISTS currentstate_membership_idx ON currentstate_current_room_state(type, state_key, membership) +WHERE membership IS NOT NULL AND membership != ` + leaveEnum + `; ` const upsertRoomStateSQL = "" + @@ -108,10 +111,10 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, txn *sql.Tx, userID string, - membership string, + membershipEnum int, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, userID, membership) + rows, err := stmt.QueryContext(ctx, userID, membershipEnum) if err != nil { return nil, err } @@ -138,7 +141,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, membership *string, + event gomatrixserverlib.HeaderedEvent, membershipEnum int, ) error { headeredJSON, err := json.Marshal(event) if err != nil { @@ -155,7 +158,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.Sender(), *event.StateKey(), headeredJSON, - membership, + membershipEnum, ) return err } diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index 581380e4d..d78b3e0ed 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -17,6 +17,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -47,16 +48,20 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr // ignore non state events continue } - var membership *string + var membershipEnum int if event.Type() == "m.room.member" { - value, err := event.Membership() + membership, err := event.Membership() if err != nil { return err } - membership = &value + enum, ok := tables.MembershipToEnum[membership] + if !ok { + return fmt.Errorf("unknown membership: %s", membership) + } + membershipEnum = enum } - if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership); err != nil { + if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membershipEnum); err != nil { return err } } @@ -65,5 +70,9 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr } func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { - return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) + enum, ok := tables.MembershipToEnum[membership] + if !ok { + return nil, fmt.Errorf("unknown membership: %s", membership) + } + return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, enum) } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index c18193276..2e2b0e423 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -35,7 +35,7 @@ CREATE TABLE IF NOT EXISTS currentstate_current_room_state ( sender TEXT NOT NULL, state_key TEXT NOT NULL, headered_event_json TEXT NOT NULL, - membership TEXT, + membership INTEGER NOT NULL DEFAULT 0, UNIQUE (room_id, type, state_key) ); -- for event deletion @@ -100,10 +100,10 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, txn *sql.Tx, userID string, - membership string, // nolint: unparam + membershipEnum int, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, userID, membership) + rows, err := stmt.QueryContext(ctx, userID, membershipEnum) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, membership *string, + event gomatrixserverlib.HeaderedEvent, membershipEnum int, ) error { headeredJSON, err := json.Marshal(event) if err != nil { @@ -147,7 +147,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.Sender(), *event.StateKey(), headeredJSON, - membership, + membershipEnum, ) return err } diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index d2e560a21..f2c8b14ed 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -21,11 +21,24 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +var MembershipToEnum = map[string]int{ + gomatrixserverlib.Invite: 1, + gomatrixserverlib.Join: 2, + gomatrixserverlib.Leave: 3, + gomatrixserverlib.Ban: 4, +} +var EnumToMembership = map[int]string{ + 1: gomatrixserverlib.Invite, + 2: gomatrixserverlib.Join, + 3: gomatrixserverlib.Leave, + 4: gomatrixserverlib.Ban, +} + type CurrentRoomState interface { SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) - UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string) error + UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membershipEnum int) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. - SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) + SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membershipEnum int) ([]string, error) } diff --git a/internal/setup/base.go b/internal/setup/base.go index 66424a609..ddf8e0fad 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -22,6 +22,7 @@ import ( "net/url" "time" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -37,6 +38,7 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" + currentstateinthttp "github.com/matrix-org/dendrite/currentstateserver/inthttp" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" eduinthttp "github.com/matrix-org/dendrite/eduserver/inthttp" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" @@ -171,6 +173,15 @@ func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { return userAPI } +// CurrentStateAPIClient returns CurrentStateInternalAPI for hitting the currentstateserver over HTTP. +func (b *BaseDendrite) CurrentStateAPIClient() currentstateAPI.CurrentStateInternalAPI { + stateAPI, err := currentstateinthttp.NewCurrentStateAPIClient(b.Cfg.CurrentStateAPIURL(), b.httpClient) + if err != nil { + logrus.WithError(err).Panic("UserAPIClient failed", b.httpClient) + } + return stateAPI +} + // EDUServerClient returns EDUServerInputAPI for hitting the EDU server over HTTP func (b *BaseDendrite) EDUServerClient() eduServerAPI.EDUServerInputAPI { e, err := eduinthttp.NewEDUServerClient(b.Cfg.EDUServerURL(), b.httpClient)