currentstate: Add QuerySharedUsers (#1217)

This will be used to determine who to send device list updates to. It
can also be used to determine who to send presence info to.
This commit is contained in:
Kegsay 2020-07-23 12:26:31 +01:00 committed by GitHub
parent cfeb1b2f42
commit 7b862384a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 232 additions and 13 deletions

View file

@ -31,6 +31,16 @@ type CurrentStateInternalAPI interface {
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
}
type QuerySharedUsersRequest struct {
UserID string
}
type QuerySharedUsersResponse struct {
UserIDs []string
} }
type QueryRoomsForUserRequest struct { type QueryRoomsForUserRequest struct {

View file

@ -16,9 +16,11 @@ package currentstateserver
import ( import (
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
@ -178,3 +180,112 @@ func TestQueryCurrentState(t *testing.T) {
runCases(currStateAPI) runCases(currStateAPI)
}) })
} }
func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent {
eb := gomatrixserverlib.EventBuilder{
RoomID: roomID,
Sender: userID,
StateKey: &userID,
Type: "m.room.member",
Content: []byte(`{"membership":"` + membership + `"}`),
}
_, pkey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("failed to make ed25519 key: %s", err)
}
roomVer := gomatrixserverlib.RoomVersionV5
ev, err := eb.Build(
time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"),
pkey, roomVer,
)
if err != nil {
t.Fatalf("mustMakeMembershipEvent failed: %s", err)
}
return &roomserverAPI.OutputNewRoomEvent{
Event: ev.Headered(roomVer),
AddsStateEventIDs: []string{ev.EventID()},
}
}
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
func TestQuerySharedUsers(t *testing.T) {
currStateAPI, producer := MustMakeInternalAPI(t)
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
testCases := []struct {
req api.QuerySharedUsersRequest
wantRes api.QuerySharedUsersResponse
}{
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
{
req: api.QuerySharedUsersRequest{
UserID: "@alice:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
},
},
// Unknown user has no shared users
{
req: api.QuerySharedUsersRequest{
UserID: "@unknownuser:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: nil,
},
},
// left real user produces no shared users
{
req: api.QuerySharedUsersRequest{
UserID: "@dave:localhost",
},
wantRes: api.QuerySharedUsersResponse{
UserIDs: nil,
},
},
}
runCases := func(testAPI api.CurrentStateInternalAPI) {
for _, tc := range testCases {
var res api.QuerySharedUsersResponse
err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res)
if err != nil {
t.Errorf("QuerySharedUsers returned error: %s", err)
continue
}
sort.Strings(res.UserIDs)
sort.Strings(tc.wantRes.UserIDs)
if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
}
}
}
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
AddInternalRoutes(router, currStateAPI)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{})
if err != nil {
t.Fatalf("failed to create HTTP client")
}
runCases(httpAPI)
})
t.Run("Monolith", func(t *testing.T) {
runCases(currStateAPI)
})
}

View file

@ -68,3 +68,16 @@ func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req
} }
return nil return nil
} }
func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil {
return err
}
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
if err != nil {
return err
}
res.UserIDs = users
return nil
}

View file

@ -29,6 +29,7 @@ const (
QueryCurrentStatePath = "/currentstateserver/queryCurrentState" QueryCurrentStatePath = "/currentstateserver/queryCurrentState"
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
) )
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@ -86,3 +87,13 @@ func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
apiURL := h.apiURL + QueryBulkStateContentPath apiURL := h.apiURL + QueryBulkStateContentPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
defer span.Finish()
apiURL := h.apiURL + QuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -64,4 +64,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QuerySharedUsersPath,
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
request := api.QuerySharedUsersRequest{}
response := api.QuerySharedUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -37,4 +37,6 @@ type Database interface {
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// Redact a state event // Redact a state event
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
// JoinedUsersSetInRooms returns all joined users in the rooms given.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
} }

View file

@ -77,6 +77,9 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" + const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -85,6 +88,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectBulkStateContentStmt *sql.Stmt selectBulkStateContentStmt *sql.Stmt
selectBulkStateContentWildStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -114,9 +118,29 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil { if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View file

@ -85,3 +85,7 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
}

View file

@ -66,6 +66,9 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" + const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -73,6 +76,7 @@ type currentRoomStateStatements struct {
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -96,9 +100,34 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
userIDs = append(userIDs, userID)
}
return userIDs, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View file

@ -36,6 +36,8 @@ type CurrentRoomState interface {
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.