From e82090e277bdadb1d75180d978e9f167840ca4a6 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 9 Mar 2017 11:47:06 +0000 Subject: [PATCH 1/2] Update gomatrixserverlib dep and add basic /createRoom validation (#31) --- .../dendrite/clientapi/common/common.go | 8 +- .../dendrite/clientapi/writers/createroom.go | 86 ++++++++ vendor/manifest | 4 +- .../matrix-org/gomatrixserverlib/event.go | 15 ++ .../matrix-org/gomatrixserverlib/eventauth.go | 184 +++++++++--------- .../gomatrixserverlib/eventauth_test.go | 53 ++++- .../gomatrixserverlib/eventcontent.go | 2 +- 7 files changed, 250 insertions(+), 102 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/common/common.go b/src/github.com/matrix-org/dendrite/clientapi/common/common.go index 72ccfb066..cbc94a6f5 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/common/common.go +++ b/src/github.com/matrix-org/dendrite/clientapi/common/common.go @@ -2,9 +2,10 @@ package common import ( "encoding/json" + "net/http" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/util" - "net/http" ) // UnmarshalJSONRequest into the given interface pointer. Returns an error JSON response if @@ -12,9 +13,12 @@ import ( func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONResponse { defer req.Body.Close() if err := json.NewDecoder(req.Body).Decode(iface); err != nil { + // TODO: We may want to suppress the Error() return in production? It's useful when + // debugging because an error will be produced for both invalid/malformed JSON AND + // valid JSON with incorrect types for values. return &util.JSONResponse{ Code: 400, - JSON: jsonerror.NotJSON("The request body was not JSON"), + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), } } return nil diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go b/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go index 945eae7b1..1dde64e74 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/createroom.go @@ -2,11 +2,14 @@ package writers import ( "encoding/json" + "fmt" "net/http" + "strings" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/common" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/util" ) @@ -22,8 +25,56 @@ type createRoomRequest struct { RoomAliasName string `json:"room_alias_name"` } +func (r createRoomRequest) Validate() *util.JSONResponse { + whitespace := "\t\n\x0b\x0c\r " // https://docs.python.org/2/library/string.html#string.whitespace + // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L81 + // Synapse doesn't check for ':' but we will else it will break parsers badly which split things into 2 segments. + if strings.ContainsAny(r.RoomAliasName, whitespace+":") { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("room_alias_name cannot contain whitespace"), + } + } + for _, userID := range r.Invite { + // TODO: We should put user ID parsing code into gomatrixserverlib and use that instead + // (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 ) + // It should be a struct (with pointers into a single string to avoid copying) and + // we should update all refs to use UserID types rather than strings. + // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92 + if len(userID) == 0 || userID[0] != '@' { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("user id must start with '@'"), + } + } + parts := strings.SplitN(userID[1:], ":", 2) + if len(parts) != 2 { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"), + } + } + } + return nil +} + +// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom +type createRoomResponse struct { + RoomID string `json:"room_id"` + RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec +} + // CreateRoom implements /createRoom func CreateRoom(req *http.Request) util.JSONResponse { + serverName := "localhost" + // TODO: Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), serverName) + return createRoom(req, roomID) +} + +// createRoom implements /createRoom +func createRoom(req *http.Request, roomID string) util.JSONResponse { logger := util.GetLogger(req.Context()) userID, resErr := auth.VerifyAccessToken(req) if resErr != nil { @@ -34,9 +85,44 @@ func CreateRoom(req *http.Request) util.JSONResponse { if resErr != nil { return *resErr } + // TODO: apply rate-limit + + if resErr = r.Validate(); resErr != nil { + return *resErr + } + + // TODO: visibility/presets/raw initial state/creation content + + // TODO: Create room alias association logger.WithFields(log.Fields{ "userID": userID, + "roomID": roomID, }).Info("Creating room") + + // send events into the room in order of: + // 1- m.room.create + // 2- room creator join member + // 3- m.room.power_levels + // 4- m.room.canonical_alias (opt) TODO + // 5- m.room.join_rules + // 6- m.room.history_visibility + // 7- m.room.guest_access (opt) TODO + // 8- other initial state items TODO + // 9- m.room.name (opt) + // 10- m.room.topic (opt) + // 11- invite events (opt) - with is_direct flag if applicable TODO + // 12- 3pid invite events (opt) TODO + // 13- m.room.aliases event for HS (if alias specified) TODO + // This differs from Synapse slightly. Synapse would vary the ordering of 3-7 + // depending on if those events were in "initial_state" or not. This made it + // harder to reason about, hence sticking to a strict static ordering. + + // f.e event: + // - validate required keys/types (EventValidator in synapse) + // - set additional keys (displayname/avatar_url for m.room.member) + // - set token(?) and txn id + // - then https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/message.py#L419 + return util.MessageResponse(404, "Not implemented yet") } diff --git a/vendor/manifest b/vendor/manifest index 26f8a5571..fe5d5bbc4 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -92,7 +92,7 @@ { "importpath": "github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib", - "revision": "48ee56a33d195dc412dd919a0e81af70c9aaf4a3", + "revision": "ce2ae9c5812346444b0ca75d57834794cde03fb7", "branch": "master" }, { @@ -206,4 +206,4 @@ "branch": "master" } ] -} \ No newline at end of file +} diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/event.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/event.go index f7595c796..e1b6ee7d8 100644 --- a/vendor/src/github.com/matrix-org/gomatrixserverlib/event.go +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/event.go @@ -326,6 +326,21 @@ func (e Event) Depth() int64 { return e.fields.Depth } +// UnmarshalJSON implements json.Unmarshaller assuming the Event is from an untrusted source. +// This will cause more checks than might be necessary but is probably better to be safe than sorry. +func (e *Event) UnmarshalJSON(data []byte) (err error) { + *e, err = NewEventFromUntrustedJSON(data) + return +} + +// MarshalJSON implements json.Marshaller +func (e Event) MarshalJSON() ([]byte, error) { + if e.eventJSON == nil { + return nil, fmt.Errorf("gomatrixserverlib: cannot serialise uninitialised Event") + } + return e.eventJSON, nil +} + // UnmarshalJSON implements json.Unmarshaller func (er *EventReference) UnmarshalJSON(data []byte) error { var tuple []rawJSON diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth.go index 151c6b99e..913e652a9 100644 --- a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth.go +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth.go @@ -18,7 +18,8 @@ package gomatrixserverlib import ( "encoding/json" "fmt" - "sort" + + "github.com/matrix-org/util" ) const ( @@ -43,105 +44,108 @@ type StateNeeded struct { ThirdPartyInvite []string } -// StateNeededForAuth returns the event types and state_keys needed to authenticate an event. -// This takes a list of events to facilitate bulk processing when doing auth checks as part of state conflict resolution. -func StateNeededForAuth(events []Event) (result StateNeeded) { - var members []string - var thirdpartyinvites []string - - for _, event := range events { - switch event.Type() { - case "m.room.create": - // The create event doesn't require any state to authenticate. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L123 - case "m.room.aliases": - // Alias events need: - // * The create event. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L128 - // Alias events need no further authentication. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L160 - result.Create = true - case "m.room.member": - // Member events need: - // * The previous membership of the target. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L355 - // * The current membership state of the sender. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L348 - // * The join rules for the room if the event is a join event. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L361 - // * The power levels for the room. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L370 - // * And optionally may require a m.third_party_invite event - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L393 - content, err := newMemberContentFromEvent(event) - if err != nil { - // If we hit an error decoding the content we ignore it here. - // The event will be rejected when the actual checks encounter the same error. - continue - } - result.Create = true - result.PowerLevels = true - stateKey := event.StateKey() - if stateKey != nil { - members = append(members, event.Sender(), *stateKey) - } - if content.Membership == join { - result.JoinRules = true - } - if content.ThirdPartyInvite != nil { - token, err := thirdPartyInviteToken(content.ThirdPartyInvite) - if err != nil { - // If we hit an error decoding the content we ignore it here. - // The event will be rejected when the actual checks encounter the same error. - continue - } else { - thirdpartyinvites = append(thirdpartyinvites, token) - } - } - - default: - // All other events need: - // * The membership of the sender. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L177 - // * The power levels for the room. - // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L196 - result.Create = true - result.PowerLevels = true - members = append(members, event.Sender()) +// StateNeededForEventBuilder returns the event types and state_keys needed to authenticate the +// event being built. These events should be put under 'auth_events' for the event being built. +// Returns an error if the state needed could not be calculated with the given builder, e.g +// if there is a m.room.member without a membership key. +func StateNeededForEventBuilder(builder *EventBuilder) (result StateNeeded, err error) { + // Extract the 'content' object from the event if it is m.room.member as we need to know 'membership' + var content *memberContent + if builder.Type == "m.room.member" { + if err = json.Unmarshal(builder.content, &content); err != nil { + err = errorf("unparsable member event content: %s", err.Error()) + return } } - - // Deduplicate the state keys. - sort.Strings(members) - result.Member = members[:unique(sort.StringSlice(members))] - sort.Strings(thirdpartyinvites) - result.ThirdPartyInvite = thirdpartyinvites[:unique(sort.StringSlice(thirdpartyinvites))] + err = accumulateStateNeeded(&result, builder.Type, builder.Sender, builder.StateKey, content) + result.Member = util.UniqueStrings(result.Member) + result.ThirdPartyInvite = util.UniqueStrings(result.ThirdPartyInvite) return } -// Remove duplicate items from a sorted list. -// Takes the same interface as sort.Sort -// Returns the length of the data without duplicates -// Uses the last occurrence of a duplicate. -// O(n). -func unique(data sort.Interface) int { - length := data.Len() - if length == 0 { - return 0 - } - j := 0 - for i := 1; i < length; i++ { - if data.Less(i-1, i) { - data.Swap(i-1, j) - j++ +// StateNeededForAuth returns the event types and state_keys needed to authenticate an event. +// This takes a list of events to facilitate bulk processing when doing auth checks as part of state conflict resolution. +func StateNeededForAuth(events []Event) (result StateNeeded) { + for _, event := range events { + // Extract the 'content' object from the event if it is m.room.member as we need to know 'membership' + var content *memberContent + if event.Type() == "m.room.member" { + c, err := newMemberContentFromEvent(event) + if err == nil { + content = &c + } } + // Ignore errors when accumulating state needed. + // The event will be rejected when the actual checks encounter the same error. + _ = accumulateStateNeeded(&result, event.Type(), event.Sender(), event.StateKey(), content) } - data.Swap(length-1, j) - return j + 1 + + // Deduplicate the state keys. + result.Member = util.UniqueStrings(result.Member) + result.ThirdPartyInvite = util.UniqueStrings(result.ThirdPartyInvite) + return +} + +func accumulateStateNeeded(result *StateNeeded, eventType, sender string, stateKey *string, content *memberContent) (err error) { + switch eventType { + case "m.room.create": + // The create event doesn't require any state to authenticate. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L123 + case "m.room.aliases": + // Alias events need: + // * The create event. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L128 + // Alias events need no further authentication. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L160 + result.Create = true + case "m.room.member": + // Member events need: + // * The previous membership of the target. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L355 + // * The current membership state of the sender. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L348 + // * The join rules for the room if the event is a join event. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L361 + // * The power levels for the room. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L370 + // * And optionally may require a m.third_party_invite event + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L393 + if content == nil { + err = errorf("missing memberContent for m.room.member event") + return + } + result.Create = true + result.PowerLevels = true + if stateKey != nil { + result.Member = append(result.Member, sender, *stateKey) + } + if content.Membership == join { + result.JoinRules = true + } + if content.ThirdPartyInvite != nil { + token, tokErr := thirdPartyInviteToken(content.ThirdPartyInvite) + if tokErr != nil { + err = errorf("could not get third-party token: %s", tokErr) + return + } + result.ThirdPartyInvite = append(result.ThirdPartyInvite, token) + } + + default: + // All other events need: + // * The membership of the sender. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L177 + // * The power levels for the room. + // https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L196 + result.Create = true + result.PowerLevels = true + result.Member = append(result.Member, sender) + } + return } // thirdPartyInviteToken extracts the token from the third_party_invite. -func thirdPartyInviteToken(thirdPartyInviteData json.RawMessage) (string, error) { +func thirdPartyInviteToken(thirdPartyInviteData rawJSON) (string, error) { var thirdPartyInvite struct { Signed struct { Token string `json:"token"` diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth_test.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth_test.go index ae9851146..71822e6d9 100644 --- a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth_test.go +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventauth_test.go @@ -68,7 +68,7 @@ func (tel *testEventList) UnmarshalJSON(data []byte) error { return nil } -func testStateNeededForAuth(t *testing.T, eventdata string, want StateNeeded) { +func testStateNeededForAuth(t *testing.T, eventdata string, builder *EventBuilder, want StateNeeded) { var events testEventList if err := json.Unmarshal([]byte(eventdata), &events); err != nil { panic(err) @@ -77,11 +77,24 @@ func testStateNeededForAuth(t *testing.T, eventdata string, want StateNeeded) { if !stateNeededEquals(got, want) { t.Errorf("Testing StateNeededForAuth(%#v), wanted %#v got %#v", events, want, got) } + if builder != nil { + got, err := StateNeededForEventBuilder(builder) + if !stateNeededEquals(got, want) { + t.Errorf("Testing StateNeededForEventBuilder(%#v), wanted %#v got %#v", events, want, got) + } + if err != nil { + panic(err) + } + } } func TestStateNeededForCreate(t *testing.T) { // Create events don't need anything. - testStateNeededForAuth(t, `[{"type": "m.room.create"}]`, StateNeeded{}) + skey := "" + testStateNeededForAuth(t, `[{"type": "m.room.create"}]`, &EventBuilder{ + Type: "m.room.create", + StateKey: &skey, + }, StateNeeded{}) } func TestStateNeededForMessage(t *testing.T) { @@ -89,7 +102,10 @@ func TestStateNeededForMessage(t *testing.T) { testStateNeededForAuth(t, `[{ "type": "m.room.message", "sender": "@u1:a" - }]`, StateNeeded{ + }]`, &EventBuilder{ + Type: "m.room.message", + Sender: "@u1:a", + }, StateNeeded{ Create: true, PowerLevels: true, Member: []string{"@u1:a"}, @@ -98,18 +114,27 @@ func TestStateNeededForMessage(t *testing.T) { func TestStateNeededForAlias(t *testing.T) { // Alias events need only the create event. - testStateNeededForAuth(t, `[{"type": "m.room.aliases"}]`, StateNeeded{ + testStateNeededForAuth(t, `[{"type": "m.room.aliases"}]`, &EventBuilder{ + Type: "m.room.aliases", + }, StateNeeded{ Create: true, }) } func TestStateNeededForJoin(t *testing.T) { + skey := "@u1:a" + b := EventBuilder{ + Type: "m.room.member", + StateKey: &skey, + Sender: "@u1:a", + } + b.SetContent(memberContent{"join", nil}) testStateNeededForAuth(t, `[{ "type": "m.room.member", "state_key": "@u1:a", "sender": "@u1:a", "content": {"membership": "join"} - }]`, StateNeeded{ + }]`, &b, StateNeeded{ Create: true, JoinRules: true, PowerLevels: true, @@ -118,12 +143,19 @@ func TestStateNeededForJoin(t *testing.T) { } func TestStateNeededForInvite(t *testing.T) { + skey := "@u2:b" + b := EventBuilder{ + Type: "m.room.member", + StateKey: &skey, + Sender: "@u1:a", + } + b.SetContent(memberContent{"invite", nil}) testStateNeededForAuth(t, `[{ "type": "m.room.member", "state_key": "@u2:b", "sender": "@u1:a", "content": {"membership": "invite"} - }]`, StateNeeded{ + }]`, &b, StateNeeded{ Create: true, PowerLevels: true, Member: []string{"@u1:a", "@u2:b"}, @@ -131,6 +163,13 @@ func TestStateNeededForInvite(t *testing.T) { } func TestStateNeededForInvite3PID(t *testing.T) { + skey := "@u2:b" + b := EventBuilder{ + Type: "m.room.member", + StateKey: &skey, + Sender: "@u1:a", + } + b.SetContent(memberContent{"invite", rawJSON(`{"signed":{"token":"my_token"}}`)}) testStateNeededForAuth(t, `[{ "type": "m.room.member", "state_key": "@u2:b", @@ -143,7 +182,7 @@ func TestStateNeededForInvite3PID(t *testing.T) { } } } - }]`, StateNeeded{ + }]`, &b, StateNeeded{ Create: true, PowerLevels: true, Member: []string{"@u1:a", "@u2:b"}, diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventcontent.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventcontent.go index 2d61d5d52..c355f8b7a 100644 --- a/vendor/src/github.com/matrix-org/gomatrixserverlib/eventcontent.go +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/eventcontent.go @@ -108,7 +108,7 @@ type memberContent struct { // We use the membership key in order to check if the user is in the room. Membership string `json:"membership"` // We use the third_party_invite key to special case thirdparty invites. - ThirdPartyInvite json.RawMessage `json:"third_party_invite"` + ThirdPartyInvite rawJSON `json:"third_party_invite,omitempty"` } // newMemberContentFromAuthEvents loads the member content from the member event for the user ID in the auth events. From e667f17e14e3fec31d9d711e44cca84ef3aadd46 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 9 Mar 2017 15:07:18 +0000 Subject: [PATCH 2/2] Include the requested current state alongside the latest events in the query API. (#30) * Return the requested portions of current state in the query API * Use Unique from github.com/matrix-org/util * rewrite bulkSelectFilteredStateBlockEntries to use append for clarity * Add test for stateKeyTupleSorter * Replace current with a new StateEntryList rather than individually setting the fields --- .../dendrite/roomserver/input/events.go | 3 - .../dendrite/roomserver/query/query.go | 39 ++++- .../roomserver-integration-tests/main.go | 10 +- .../dendrite/roomserver/state/state.go | 113 ++++++++++++++- .../roomserver/storage/event_types_table.go | 32 ++++- .../roomserver/storage/rooms_table.go | 11 +- .../roomserver/storage/state_block_table.go | 133 ++++++++++++++++-- .../storage/state_block_table_test.go | 70 +++++++++ .../dendrite/roomserver/storage/storage.go | 32 +++-- vendor/manifest | 2 +- 10 files changed, 410 insertions(+), 35 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index 600ec6ae9..e7b791b03 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -16,9 +16,6 @@ type RoomEventDatabase interface { // Returns an error if the there is an error talking to the database // or if the event IDs aren't in the database. StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) - // Lookup the numeric IDs for a list of string event state keys. - // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) // Lookup the Events for a list of numeric event IDs. // Returns a sorted list of events. Events(eventNIDs []types.EventNID) ([]types.Event, error) diff --git a/src/github.com/matrix-org/dendrite/roomserver/query/query.go b/src/github.com/matrix-org/dendrite/roomserver/query/query.go index 0be51886f..b08b99e7e 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/query/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/query/query.go @@ -3,6 +3,7 @@ package query import ( "encoding/json" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -12,13 +13,17 @@ import ( // RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. type RoomserverQueryAPIDatabase interface { + state.RoomStateDatabase // Lookup the numeric ID for the room. // Returns 0 if the room doesn't exists. // Returns an error if there was a problem talking to the database. RoomNID(roomID string) (types.RoomNID, error) - // Lookup event references for the latest events in the room. + // Lookup event references for the latest events in the room and the current state snapshot. // Returns an error if there was a problem talking to the database. - LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) + LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) + // Lookup the Events for a list of numeric event IDs. + // Returns a list of events sorted by numeric event ID. + Events(eventNIDs []types.EventNID) ([]types.Event, error) } // RoomserverQueryAPI is an implementation of RoomserverQueryAPI @@ -40,9 +45,33 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( return nil } response.RoomExists = true - response.LatestEvents, err = r.DB.LatestEventIDs(roomNID) - // TODO: look up the current state. - return err + var currentStateSnapshotNID types.StateSnapshotNID + response.LatestEvents, currentStateSnapshotNID, err = r.DB.LatestEventIDs(roomNID) + if err != nil { + return err + } + + // Lookup the currrent state for the requested tuples. + stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch) + if err != nil { + return err + } + + eventNIDs := make([]types.EventNID, len(stateEntries)) + for i := range stateEntries { + eventNIDs[i] = stateEntries[i].EventNID + } + + stateEvents, err := r.DB.Events(eventNIDs) + if err != nil { + return err + } + + response.StateEvents = make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + response.StateEvents[i] = stateEvents[i].Event + } + return nil } // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go index e7cd2d1fa..96a00577b 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go @@ -365,7 +365,12 @@ func main() { testRoomserver(input, want, func(q api.RoomserverQueryAPI) { var response api.QueryLatestEventsAndStateResponse if err := q.QueryLatestEventsAndState( - &api.QueryLatestEventsAndStateRequest{RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org"}, + &api.QueryLatestEventsAndStateRequest{ + RoomID: "!HCXfdvrfksxuYnIFiJ:matrix.org", + StateToFetch: []api.StateKeyTuple{ + {"m.room.member", "@richvdh:matrix.org"}, + }, + }, &response, ); err != nil { panic(err) @@ -376,6 +381,9 @@ func main() { if len(response.LatestEvents) != 1 || response.LatestEvents[0].EventID != "$1463671339126270PnVwC:matrix.org" { panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the latest event got %#v`, response.LatestEvents)) } + if len(response.StateEvents) != 1 || response.StateEvents[0].EventID() != "$1463671339126270PnVwC:matrix.org" { + panic(fmt.Errorf(`Wanted "$1463671339126270PnVwC:matrix.org" to be the state event got %#v`, response.StateEvents)) + } }) fmt.Println("==PASSED==", os.Args[0]) diff --git a/src/github.com/matrix-org/dendrite/roomserver/state/state.go b/src/github.com/matrix-org/dendrite/roomserver/state/state.go index aadc5550f..ec7b8e08a 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/state/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/state/state.go @@ -4,6 +4,7 @@ package state import ( "fmt" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" "sort" @@ -11,12 +12,25 @@ import ( // A RoomStateDatabase has the storage APIs needed to load state from the database type RoomStateDatabase interface { + // Lookup the numeric IDs for a list of string event types. + // Returns a map from string event type to numeric ID for the event type. + EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) + // Lookup the numeric IDs for a list of string event state keys. + // Returns a map from string state key to numeric ID for the state key. + EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) // Lookup the numeric state data IDs for each numeric state snapshot ID // The returned slice is sorted by numeric state snapshot ID. StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) // Lookup the state data for each numeric state data ID // The returned slice is sorted by numeric state data ID. StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + // Lookup the state data for the state key tuples for each numeric state block ID + // This is used to fetch a subset of the room state at a snapshot. + // If a block doesn't contain any of the requested tuples then it can be discarded from the result. + // The returned slice is sorted by numeric state block ID. + StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ( + []types.StateEntryList, error, + ) } // LoadStateAtSnapshot loads the full state of a room at a particular snapshot. @@ -27,6 +41,7 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) if err != nil { return nil, err } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. stateBlockNIDList := stateBlockNIDLists[0] stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) @@ -35,7 +50,7 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) } stateEntriesMap := stateEntryListMap(stateEntryLists) - // Combined all the state entries for this snapshot. + // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { @@ -98,7 +113,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) } - // Combined all the state entries for this snapshot. + // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDs { @@ -182,6 +197,100 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat } } +// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs +// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. +// Returns an error if there was a problem talking to the database. +func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []api.StateKeyTuple) ([]types.StateKeyTuple, error) { + eventTypes := make([]string, len(stringTuples)) + stateKeys := make([]string, len(stringTuples)) + for i := range stringTuples { + eventTypes[i] = stringTuples[i].EventType + stateKeys[i] = stringTuples[i].EventStateKey + } + eventTypes = util.UniqueStrings(eventTypes) + eventTypeMap, err := db.EventTypeNIDs(eventTypes) + if err != nil { + return nil, err + } + stateKeys = util.UniqueStrings(stateKeys) + stateKeyMap, err := db.EventStateKeyNIDs(stateKeys) + if err != nil { + return nil, err + } + + var result []types.StateKeyTuple + for _, stringTuple := range stringTuples { + var numericTuple types.StateKeyTuple + var ok1, ok2 bool + numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] + numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.EventStateKey] + // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. + if ok1 && ok2 { + result = append(result, numericTuple) + } + } + + return result, nil +} + +// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func LoadStateAtSnapshotForStringTuples( + db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []api.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) + if err != nil { + return nil, err + } + return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples) +} + +// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func loadStateAtSnapshotForNumericTuples( + db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // If the block is missing from the map it means that none of its entries matched a requested tuple. + // This can happen if the block doesn't contain an update for one of the requested tuples. + // If none of the requested tuples are in the block then it can be safely skipped. + continue + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + type stateBlockNIDListMap []types.StateBlockNIDList func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go index 26fe05388..197fa8df4 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_types_table.go @@ -2,6 +2,7 @@ package storage import ( "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,9 +67,16 @@ const insertEventTypeNIDSQL = "" + const selectEventTypeNIDSQL = "" + "SELECT event_type_nid FROM event_types WHERE event_type = $1" +// Bulk lookup from string event type to numeric ID for that event type. +// Takes an array of strings as the query parameter. +const bulkSelectEventTypeNIDSQL = "" + + "SELECT event_type, event_type_nid FROM event_types" + + " WHERE event_type = ANY($1)" + type eventTypeStatements struct { - insertEventTypeNIDStmt *sql.Stmt - selectEventTypeNIDStmt *sql.Stmt + insertEventTypeNIDStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt *sql.Stmt } func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +88,7 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, }.prepare(db) } @@ -94,3 +103,22 @@ func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventT err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } + +func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) { + rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes)) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[string]types.EventTypeNID, len(eventTypes)) + for rows.Next() { + var eventType string + var eventTypeNID int64 + if err := rows.Scan(&eventType, &eventTypeNID); err != nil { + return nil, err + } + result[eventType] = types.EventTypeNID(eventTypeNID) + } + return result, nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index f81715450..ed6bd06f2 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -35,7 +35,7 @@ const selectRoomNIDSQL = "" + "SELECT room_nid FROM rooms WHERE room_id = $1" const selectLatestEventNIDsSQL = "" + - "SELECT latest_event_nids FROM rooms WHERE room_nid = $1" + "SELECT latest_event_nids, state_snapshot_nid FROM rooms WHERE room_nid = $1" const selectLatestEventNIDsForUpdateSQL = "" + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM rooms WHERE room_nid = $1 FOR UPDATE" @@ -77,17 +77,18 @@ func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) { return types.RoomNID(roomNID), err } -func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, error) { +func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array - err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids) + var stateSnapshotNID int64 + err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { - return nil, err + return nil, 0, err } eventNIDs := make([]types.EventNID, len(nids)) for i := range nids { eventNIDs[i] = types.EventNID(nids[i]) } - return eventNIDs, nil + return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil } func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ( diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go index 440b539b1..14fed0f56 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table.go @@ -5,6 +5,8 @@ import ( "fmt" "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" + "sort" ) const stateDataSchema = ` @@ -35,21 +37,35 @@ const insertStateDataSQL = "" + const selectNextStateBlockNIDSQL = "" + "SELECT nextval('state_block_nid_seq')" -// Bulk state lookup by numeric event ID. +// Bulk state lookup by numeric state block ID. // Sort by the state_block_nid, event_type_nid, event_state_key_nid // This means that all the entries for a given state_block_nid will appear // together in the list and those entries will sorted by event_type_nid // and event_state_key_nid. This property makes it easier to merge two // state data blocks together. -const bulkSelectStateDataEntriesSQL = "" + +const bulkSelectStateBlockEntriesSQL = "" + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + " FROM state_block WHERE state_block_nid = ANY($1)" + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" +// Bulk state lookup by numeric state block ID. +// Filters the rows in each block to the requested types and state keys. +// We would like to restrict to particular type state key pairs but we are +// restricted by the query language to pull the cross product of a list +// of types and a list state_keys. So we have to filter the result in the +// application to restrict it to the list of event types and state keys we +// actually wanted. +const bulkSelectFilteredStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM state_block WHERE state_block_nid = ANY($1)" + + " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + type stateBlockStatements struct { - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateDataEntriesStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt + bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { @@ -61,7 +77,8 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertStateDataStmt, insertStateDataSQL}, {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, - {&s.bulkSelectStateDataEntriesStmt, bulkSelectStateDataEntriesSQL}, + {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, + {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.prepare(db) } @@ -86,12 +103,12 @@ func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, e return types.StateBlockNID(stateBlockNID), err } -func (s *stateBlockStatements) bulkSelectStateDataEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { +func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids)) + rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } @@ -131,3 +148,103 @@ func (s *stateBlockStatements) bulkSelectStateDataEntries(stateBlockNIDs []types } return results, nil } + +func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( + stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. + sort.Sort(tuples) + + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query( + stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []types.StateEntryList + var current types.StateEntryList + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + + // We can use binary search here because we sorted the tuples earlier + if !tuples.contains(entry.StateKeyTuple) { + // The select will return the cross product of types and state keys. + // So we need to check if type of the entry is in the list. + continue + } + + if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we append the current entry to the results and start adding to a new one. + // The first time through the loop current will be empty. + if current.StateEntries != nil { + results = append(results, current) + } + current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} + } + current.StateEntries = append(current.StateEntries, entry) + } + // Add the last entry to the list if it is not empty. + if current.StateEntries != nil { + results = append(results, current) + } + return results, nil +} + +func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + return pq.Int64Array(nids) +} + +type stateKeyTupleSorter []types.StateKeyTuple + +func (s stateKeyTupleSorter) Len() int { return len(s) } +func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { + eventTypeNIDs = make(pq.Int64Array, len(s)) + eventStateKeyNIDs = make(pq.Int64Array, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go new file mode 100644 index 000000000..e0a142296 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/state_block_table_test.go @@ -0,0 +1,70 @@ +package storage + +import ( + "github.com/matrix-org/dendrite/roomserver/types" + "sort" + "testing" +) + +func TestStateKeyTupleSorter(t *testing.T) { + input := stateKeyTupleSorter{ + {1, 2}, + {1, 4}, + {2, 2}, + {1, 1}, + } + want := []types.StateKeyTuple{ + {1, 1}, + {1, 2}, + {1, 4}, + {2, 2}, + } + doNotWant := []types.StateKeyTuple{ + {0, 0}, + {1, 3}, + {2, 1}, + {3, 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 989d91b0b..d75c8d380 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -145,7 +145,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr return d.statements.bulkSelectStateEventByID(eventIDs) } -// EventStateKeyNIDs implements input.EventDatabase +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) { + return d.statements.bulkSelectEventTypeNID(eventTypes) +} + +// EventStateKeyNIDs implements state.RoomStateDatabase func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) } @@ -195,14 +200,14 @@ func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, err return d.statements.bulkSelectStateAtEventByID(eventIDs) } -// StateBlockNIDs implements input.EventDatabase +// StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { return d.statements.bulkSelectStateBlockNIDs(stateNIDs) } -// StateEntries implements input.EventDatabase +// StateEntries implements state.RoomStateDatabase func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateDataEntries(stateBlockNIDs) + return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs) } // EventIDs implements input.RoomEventDatabase @@ -324,10 +329,21 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { } // LatestEventIDs implements query.RoomserverQueryAPIDB -func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, error) { - eventNIDs, err := d.statements.selectLatestEventNIDs(roomNID) +func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) { + eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID) if err != nil { - return nil, err + return nil, 0, err } - return d.statements.bulkSelectEventReference(eventNIDs) + references, err := d.statements.bulkSelectEventReference(eventNIDs) + if err != nil { + return nil, 0, err + } + return references, currentStateSnapshotNID, nil +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples) } diff --git a/vendor/manifest b/vendor/manifest index fe5d5bbc4..6a3e8ee53 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -206,4 +206,4 @@ "branch": "master" } ] -} +} \ No newline at end of file