diff --git a/currentstateserver/acls/acls.go b/currentstateserver/acls/acls.go index 12619f5fc..775b6c73a 100644 --- a/currentstateserver/acls/acls.go +++ b/currentstateserver/acls/acls.go @@ -23,17 +23,25 @@ import ( "strings" "sync" - "github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + type ServerACLs struct { acls map[string]*serverACL // room ID -> ACL aclsMutex sync.RWMutex // protects the above } -func NewServerACLs(db storage.Database) *ServerACLs { +func NewServerACLs(db ServerACLDatabase) *ServerACLs { ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index fa745e286..6dc8621b2 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias( return fmt.Errorf("not implemented") } +func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + return nil +} + +func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + return nil +} + type testStateAPI struct { } diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go new file mode 100644 index 000000000..775b6c73a --- /dev/null +++ b/roomserver/acls/acls.go @@ -0,0 +1,164 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "context" + "encoding/json" + "fmt" + "net" + "regexp" + "strings" + "sync" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type ServerACLs struct { + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above +} + +func NewServerACLs(db ServerACLDatabase) *ServerACLs { + ctx := context.TODO() + acls := &ServerACLs{ + acls: make(map[string]*serverACL), + } + // Look up all of the rooms that the current state server knows about. + rooms, err := db.GetKnownRooms(ctx) + if err != nil { + logrus.WithError(err).Fatalf("Failed to get known rooms") + } + // For each room, let's see if we have a server ACL state event. If we + // do then we'll process it into memory so that we have the regexes to + // hand. + for _, room := range rooms { + state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "") + if err != nil { + logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) + continue + } + if state != nil { + acls.OnServerACLUpdate(&state.Event) + } + } + return acls +} + +type ServerACL struct { + Allowed []string `json:"allow"` + Denied []string `json:"deny"` + AllowIPLiterals bool `json:"allow_ip_literals"` +} + +type serverACL struct { + ServerACL + allowedRegexes []*regexp.Regexp + deniedRegexes []*regexp.Regexp +} + +func compileACLRegex(orig string) (*regexp.Regexp, error) { + escaped := regexp.QuoteMeta(orig) + escaped = strings.Replace(escaped, "\\?", ".", -1) + escaped = strings.Replace(escaped, "\\*", ".*", -1) + return regexp.Compile(escaped) +} + +func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { + acls := &serverACL{} + if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { + logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") + return + } + // The spec calls only for * (zero or more chars) and ? (exactly one char) + // to be supported as wildcard components, so we will escape all of the regex + // special characters and then replace * and ? with their regex counterparts. + // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl + for _, orig := range acls.Allowed { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile allowed regex") + } else { + acls.allowedRegexes = append(acls.allowedRegexes, expr) + } + } + for _, orig := range acls.Denied { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile denied regex") + } else { + acls.deniedRegexes = append(acls.deniedRegexes, expr) + } + } + logrus.WithFields(logrus.Fields{ + "allow_ip_literals": acls.AllowIPLiterals, + "num_allowed": len(acls.allowedRegexes), + "num_denied": len(acls.deniedRegexes), + }).Debugf("Updating server ACLs for %q", state.RoomID()) + s.aclsMutex.Lock() + defer s.aclsMutex.Unlock() + s.acls[state.RoomID()] = acls +} + +func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { + s.aclsMutex.RLock() + // First of all check if we have an ACL for this room. If we don't then + // no servers are banned from the room. + acls, ok := s.acls[roomID] + if !ok { + s.aclsMutex.RUnlock() + return false + } + s.aclsMutex.RUnlock() + // Split the host and port apart. This is because the spec calls on us to + // validate the hostname only in cases where the port is also present. + if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { + serverName = gomatrixserverlib.ServerName(serverNameOnly) + } + // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding + // a /0 prefix length just to trick ParseCIDR into working. If we find that + // the server is an IP literal and we don't allow those then stop straight + // away. + if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil { + if !acls.AllowIPLiterals { + return true + } + } + // Check if the hostname matches one of the denied regexes. If it does then + // the server is banned from the room. + for _, expr := range acls.deniedRegexes { + if expr.MatchString(string(serverName)) { + return true + } + } + // Check if the hostname matches one of the allowed regexes. If it does then + // the server is NOT banned from the room. + for _, expr := range acls.allowedRegexes { + if expr.MatchString(string(serverName)) { + return false + } + } + // If we've got to this point then we haven't matched any regexes or an IP + // hostname if disallowed. The spec calls for default-deny here. + return true +} diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go new file mode 100644 index 000000000..9fb6a5581 --- /dev/null +++ b/roomserver/acls/acls_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "regexp" + "testing" +) + +func TestOpenACLsWithBlacklist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("*") + if err != nil { + t.Fatalf(err.Error()) + } + denyRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: true, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{denyRegex}, + } + + if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com:4567", roomID) { + t.Fatal("Expected bar.com:4567 to be allowed but wasn't") + } +} + +func TestDefaultACLsWithWhitelist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: false, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{}, + } + + if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("baz.com", roomID) { + t.Fatal("Expected baz.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) { + t.Fatal("Expected qux.com:4567 to be allowed but wasn't") + } +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 0fe30b8b5..96bdc767e 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -106,6 +106,20 @@ type RoomserverInternalAPI interface { response *QueryStateAndAuthChainResponse, ) error + // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from + // the response. + QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error + // QueryRoomsForUser retrieves a list of room IDs matching the given query. + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error + // QueryBulkStateContent does a bulk query for state event content in the given rooms. + QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error + // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. + QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. + QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + // Query a given amount (or less) of events prior to a given set of events. PerformBackfill( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 9b53aa88c..25da2e8e0 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -236,6 +236,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias( return err } +func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error { + err := t.Impl.QueryCurrentState(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryRoomsForUser retrieves a list of room IDs matching the given query. +func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error { + err := t.Impl.QueryRoomsForUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryBulkStateContent does a bulk query for state event content in the given rooms. +func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error { + err := t.Impl.QueryBulkStateContent(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res)) + return err +} + +// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. +func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error { + err := t.Impl.QuerySharedUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryKnownUsers returns a list of users that we know about from our joined rooms. +func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error { + err := t.Impl.QueryKnownUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. +func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { + err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 4e1d09c30..d0d0474d8 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,6 +17,11 @@ package api import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct { // The list of published rooms. RoomIDs []string } + +type QuerySharedUsersRequest struct { + UserID string + ExcludeRoomIDs []string + IncludeRoomIDs []string +} + +type QuerySharedUsersResponse struct { + UserIDsToCount map[string]int +} + +type QueryRoomsForUserRequest struct { + UserID string + // The desired membership of the user. If this is the empty string then no rooms are returned. + WantMembership string +} + +type QueryRoomsForUserResponse struct { + RoomIDs []string +} + +type QueryBulkStateContentRequest struct { + // Returns state events in these rooms + RoomIDs []string + // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*' + AllowWildcards bool + // The state events to return. Only a small subset of tuples are allowed in this request as only certain events + // have their content fields extracted. Specifically, the tuple Type must be one of: + // m.room.avatar + // m.room.create + // m.room.canonical_alias + // m.room.guest_access + // m.room.history_visibility + // m.room.join_rules + // m.room.member + // m.room.name + // m.room.topic + // Any other tuple type will result in the query failing. + StateTuples []gomatrixserverlib.StateKeyTuple +} +type QueryBulkStateContentResponse struct { + // map of room ID -> tuple -> content_value + Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string +} + +type QueryCurrentStateRequest struct { + RoomID string + StateTuples []gomatrixserverlib.StateKeyTuple +} + +type QueryCurrentStateResponse struct { + StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent +} + +type QueryKnownUsersRequest struct { + UserID string `json:"user_id"` + SearchString string `json:"search_string"` + Limit int `json:"limit"` +} + +type QueryKnownUsersResponse struct { + Users []authtypes.FullyQualifiedProfile `json:"profiles"` +} + +type QueryServerBannedFromRoomRequest struct { + ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomID string `json:"room_id"` +} + +type QueryServerBannedFromRoomResponse struct { + Banned bool `json:"banned"` +} + +// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. +func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { + se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) + for k, v := range r.StateEvents { + // use 0x1F (unit separator) as the delimiter between type/state key, + se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v + } + return json.Marshal(se) +} + +func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { + res := make(map[string]*gomatrixserverlib.HeaderedEvent) + err := json.Unmarshal(data, &res) + if err != nil { + return err + } + r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) + for k, v := range res { + fields := strings.Split(k, "\x1F") + r.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: fields[0], + StateKey: fields[1], + }] = v + } + return nil +} diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 16f5e8e18..82a4a5719 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -133,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) } return &res.Events[0] } + +// GetStateEvent returns the current state event in the room or nil. +func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { + var res QueryCurrentStateResponse + err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ + RoomID: roomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState") + return nil + } + ev, ok := res.StateEvents[tuple] + if ok { + return ev + } + return nil +} + +// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. +func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { + req := &QueryServerBannedFromRoomRequest{ + ServerName: serverName, + RoomID: roomID, + } + res := &QueryServerBannedFromRoomResponse{} + if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom") + return true + } + return res.Banned +} + +// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the +// published room directory. +// due to lots of switches +// nolint:gocyclo +func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { + avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} + nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} + guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + + var stateRes QueryBulkStateContentResponse + err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ + RoomIDs: roomIDs, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, + {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + }, + }, &stateRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") + return nil, err + } + chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + i := 0 + for roomID, data := range stateRes.Rooms { + pub := gomatrixserverlib.PublicRoom{ + RoomID: roomID, + } + joinCount := 0 + var joinRule, guestAccess string + for tuple, contentVal := range data { + if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + joinCount++ + continue + } + switch tuple { + case avatarTuple: + pub.AvatarURL = contentVal + case nameTuple: + pub.Name = contentVal + case topicTuple: + pub.Topic = contentVal + case canonicalTuple: + pub.CanonicalAlias = contentVal + case visibilityTuple: + pub.WorldReadable = contentVal == "world_readable" + // need both of these to determine whether guests can join + case joinRuleTuple: + joinRule = contentVal + case guestTuple: + guestAccess = contentVal + } + } + if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + pub.GuestCanJoin = true + } + pub.JoinedMembersCount = joinCount + chunk[i] = pub + i++ + } + return chunk, nil +} diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 93c0be77b..bdea650ea 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -7,6 +7,7 @@ import ( fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/perform" @@ -46,8 +47,9 @@ func NewRoomserverAPI( ServerName: cfg.Matrix.ServerName, KeyRing: keyRing, Queryer: &query.Queryer{ - DB: roomserverDB, - Cache: caches, + DB: roomserverDB, + Cache: caches, + ServerACLs: acls.NewServerACLs(roomserverDB), }, Inputer: &input.Inputer{ DB: roomserverDB, diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b2799aefb..f76c93166 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,9 +16,12 @@ package query import ( "context" + "errors" "fmt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" @@ -31,8 +34,9 @@ import ( ) type Queryer struct { - DB storage.Database - Cache caching.RoomServerCaches + DB storage.Database + Cache caching.RoomServerCaches + ServerACLs *acls.ServerACLs } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -502,3 +506,97 @@ func (r *Queryer) QueryPublishedRooms( res.RoomIDs = rooms return nil } + +func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + for _, tuple := range req.StateTuples { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } + } + return nil +} + +func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) + if err != nil { + return err + } + res.RoomIDs = roomIDs + return nil +} + +func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) + if err != nil { + return err + } + for _, user := range users { + res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ + UserID: user, + }) + } + return nil +} + +func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) + if err != nil { + return err + } + res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for _, ev := range events { + if res.Rooms[ev.RoomID] == nil { + res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string) + } + room := res.Rooms[ev.RoomID] + room[gomatrixserverlib.StateKeyTuple{ + EventType: ev.EventType, + StateKey: ev.StateKey, + }] = ev.ContentValue + res.Rooms[ev.RoomID] = room + } + return nil +} + +func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") + if err != nil { + return err + } + roomIDs = append(roomIDs, req.IncludeRoomIDs...) + excludeMap := make(map[string]bool) + for _, roomID := range req.ExcludeRoomIDs { + excludeMap[roomID] = true + } + // filter out excluded rooms + j := 0 + for i := range roomIDs { + // move elements to include to the beginning of the slice + // then trim elements on the right + if !excludeMap[roomIDs[i]] { + roomIDs[j] = roomIDs[i] + j++ + } + } + roomIDs = roomIDs[:j] + + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + if err != nil { + return err + } + res.UserIDsToCount = users + return nil +} + +func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + if r.ServerACLs == nil { + return errors.New("no server ACL tracking") + } + res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1657bcdeb..b414b0d8c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -43,6 +43,12 @@ const ( RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" + RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState" + RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser" + RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" + RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" + RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" + RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" ) type httpRoomserverInternalAPI struct { @@ -371,3 +377,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( } return err } + +func (h *httpRoomserverInternalAPI) QueryCurrentState( + ctx context.Context, + request *api.QueryCurrentStateRequest, + response *api.QueryCurrentStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryRoomsForUser( + ctx context.Context, + request *api.QueryRoomsForUserRequest, + response *api.QueryRoomsForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryBulkStateContent( + ctx context.Context, + request *api.QueryBulkStateContentRequest, + response *api.QueryBulkStateContentResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QuerySharedUsers( + ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryKnownUsers( + ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( + ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 0ac36a2a4..ebfb296d8 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryCurrentStatePath, + httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { + request := api.QueryCurrentStateRequest{} + response := api.QueryCurrentStateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryRoomsForUserPath, + httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { + request := api.QueryRoomsForUserRequest{} + response := api.QueryRoomsForUserResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryBulkStateContentPath, + httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { + request := api.QueryBulkStateContentRequest{} + response := api.QueryBulkStateContentResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { + request := api.QuerySharedUsersRequest{} + response := api.QuerySharedUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryKnownUsersRequest{} + response := api.QueryKnownUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, + httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { + request := api.QueryServerBannedFromRoomRequest{} + response := api.QueryServerBannedFromRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ef7a9f090..c4119f7ed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -17,6 +17,7 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" @@ -138,4 +139,22 @@ type Database interface { PublishRoom(ctx context.Context, roomID string, publish bool) error // Returns a list of room IDs for rooms which are published. GetPublishedRooms(ctx context.Context) ([]string, error) + + // TODO: factor out - from currentstateserver + + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). + GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. + // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. + GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) + // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. + JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetKnownUsers searches all users that userID knows about. + GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 13cef638f..0799647e9 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -99,6 +99,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + " WHERE room_nid = $1 AND target_nid = $2" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -108,6 +111,7 @@ type membershipStatements struct { selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 13c8e703d..9d359146a 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,6 +21,7 @@ import ( "errors" "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -82,6 +89,8 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt + bulkSelectRoomIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, @@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + var array pq.Int64Array + for _, nid := range roomNIDs { + array = append(array, int64(nid)) + } + rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6e0ebd2c2..5c447d66f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" + csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -711,3 +712,82 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { } return &evs[0] } + +// GetStateEvent returns the current state event of a given type for a given room with a given state key +// If no event could be found, returns nil +// If there was an issue during the retrieval, returns an error +func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { + /* + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) + if err != nil { + return nil, err + } + */ + return nil, nil +} + +// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). +func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { + var membershipState tables.MembershipState + switch membership { + case "join": + membershipState = tables.MembershipStateJoin + case "invite": + membershipState = tables.MembershipStateInvite + case "leave": + membershipState = tables.MembershipStateLeaveOrBan + case "ban": + membershipState = tables.MembershipStateLeaveOrBan + default: + return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) + } + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + if err != nil { + return nil, err + } + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + if err != nil { + return nil, err + } + if len(roomIDs) != len(roomNIDs) { + return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs)) + } + return roomIDs, nil +} + +// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. +// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. +func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// GetKnownUsers searches all users that userID knows about. +func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// GetKnownRooms returns a list of all rooms we know about. +func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { + return d.RoomsTable.SelectRoomIDs(ctx) +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index b3ee69c00..e850c80bb 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -75,6 +75,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + " WHERE room_nid = $4 AND target_nid = $5" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -84,6 +87,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt } @@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 4c1699d00..daacf86fa 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "errors" "fmt" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -73,6 +81,7 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string @@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index c599dd3fe..126c27b57 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -65,6 +65,8 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) } type Transactions interface { @@ -120,6 +122,7 @@ type Membership interface { SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) } type Published interface {