From c89321c3ff17cea1b70ddd2520f90dbe2ed72062 Mon Sep 17 00:00:00 2001 From: Alex Kursell Date: Tue, 22 Jun 2021 20:10:31 -0400 Subject: [PATCH] Respect published/shared rooms --- clientapi/routing/userdirectory.go | 39 ++------- roomserver/api/api.go | 2 - roomserver/api/api_trace.go | 7 -- roomserver/api/query.go | 9 --- roomserver/internal/input/input_membership.go | 15 +++- roomserver/internal/query/query.go | 13 --- roomserver/inthttp/client.go | 11 --- roomserver/inthttp/server.go | 13 --- roomserver/storage/interface.go | 2 - .../storage/postgres/membership_table.go | 28 +------ .../storage/shared/membership_updater.go | 8 +- roomserver/storage/shared/storage.go | 13 +-- .../20210622210900_add_displayname_column.go | 79 +++++++++++++++++++ .../storage/sqlite3/membership_table.go | 53 ++++--------- roomserver/storage/sqlite3/storage.go | 4 +- roomserver/storage/tables/interface.go | 3 +- 16 files changed, 131 insertions(+), 168 deletions(-) create mode 100644 roomserver/storage/sqlite3/deltas/20210622210900_add_displayname_column.go diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index ae954fef9..9f8f0184a 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -49,18 +49,18 @@ func SearchUserDirectory( Limited: false, } - // First start searching users in public rooms - userReq := &rsapi.QueryPublicUsersRequest{ + stateReq := &rsapi.QueryKnownUsersRequest{ + UserID: device.UserID, SearchString: searchString, Limit: limit, } - userRes := &rsapi.QueryPublicUsersResponse{} - if err := rsAPI.QueryPublicUsers(ctx, userReq, userRes); err != nil { - errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryPublicUsers: %w", err)) + stateRes := &rsapi.QueryKnownUsersResponse{} + if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) return &errRes } - for _, user := range userRes.Users { + for _, user := range stateRes.Users { if len(results) == limit { response.Limited = true break @@ -71,33 +71,6 @@ func SearchUserDirectory( } } - // Then, if we have enough room left in the response, - // start searching for known users from joined rooms. - - if len(results) <= limit { - stateReq := &rsapi.QueryKnownUsersRequest{ - UserID: device.UserID, - SearchString: searchString, - Limit: limit - len(results), - } - stateRes := &rsapi.QueryKnownUsersResponse{} - if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { - errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) - return &errRes - } - - for _, user := range stateRes.Users { - if len(results) == limit { - response.Limited = true - break - } - - if _, ok := results[user.UserID]; !ok { - results[user.UserID] = user - } - } - } - for _, result := range results { response.Results = append(response.Results, result) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 2ba1a1b9f..72e406ee8 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -160,8 +160,6 @@ type RoomserverInternalAPI interface { QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryKnownUsers returns a list of users that we know about from our joined rooms. QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error - // QueryPublicUsers returns a list of users in at least 1 public room. - QueryPublicUsers(ctx context.Context, req *QueryPublicUsersRequest, res *QueryPublicUsersResponse) error // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 6da433c8d..1a2b9a490 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -332,13 +332,6 @@ func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *Q return err } -// QueryPublicUsers returns a list of users that are in at least one public room. -func (t *RoomserverInternalAPITrace) QueryPublicUsers(ctx context.Context, req *QueryPublicUsersRequest, res *QueryPublicUsersResponse) error { - err := t.Impl.QueryPublicUsers(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryPublicUsers req=%+v res=%+v", js(req), js(res)) - return err -} - // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) diff --git a/roomserver/api/query.go b/roomserver/api/query.go index f1223e362..af35f7e72 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -348,15 +348,6 @@ type QueryKnownUsersResponse struct { Users []authtypes.FullyQualifiedProfile `json:"profiles"` } -type QueryPublicUsersRequest struct { - SearchString string `json:"search_string"` - Limit int `json:"limit"` -} - -type QueryPublicUsersResponse struct { - Users []authtypes.FullyQualifiedProfile `json:"profiles"` -} - type QueryServerBannedFromRoomRequest struct { ServerName gomatrixserverlib.ServerName `json:"server_name"` RoomID string `json:"room_id"` diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 44435bfd9..33a7358e8 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -155,11 +155,22 @@ func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { func updateToJoinMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { + // Extract the displayname, if there is one, from the membership event + // so that we can add it as a queryable field to the membership table + memberEvent, err := gomatrixserverlib.NewMemberContentFromEvent(add) + if err != nil { + return nil, err + } + var displayname *string + if memberEvent.DisplayName != "" { + displayname = &memberEvent.DisplayName + } + // If the user is already marked as being joined, we call SetToJoin to update // the event ID then we can return immediately. Retired is ignored as there // is no invite event to retire. if mu.IsJoin() { - _, err := mu.SetToJoin(add.Sender(), add.EventID(), true) + _, err := mu.SetToJoin(add.Sender(), add.EventID(), displayname, true) if err != nil { return nil, err } @@ -169,7 +180,7 @@ func updateToJoinMembership( // are active for that user. We notify the consumers that the invites have // been retired using a special event, even though they could infer this // by studying the state changes in the room event stream. - retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) + retired, err := mu.SetToJoin(add.Sender(), add.EventID(), displayname, false) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 357b24066..408f9766e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -683,19 +683,6 @@ func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersR return nil } -func (r *Queryer) QueryPublicUsers(ctx context.Context, req *api.QueryPublicUsersRequest, res *api.QueryPublicUsersResponse) error { - users, err := r.DB.GetPublicUsers(ctx, req.SearchString, req.Limit) - if err != nil { - return err - } - for _, user := range users { - res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ - UserID: user, - }) - } - return nil -} - func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) if err != nil { diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index ba65b9749..6774d102d 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -55,7 +55,6 @@ const ( RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" - RoomserverQueryPublicUsersPath = "/roomserver/queryPublicUsers" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" ) @@ -522,16 +521,6 @@ func (h *httpRoomserverInternalAPI) QueryKnownUsers( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } -func (h *httpRoomserverInternalAPI) QueryPublicUsers( - ctx context.Context, req *api.QueryPublicUsersRequest, res *api.QueryPublicUsersResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicUsers") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryPublicUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) -} - func (h *httpRoomserverInternalAPI) QueryAuthChain( ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse, ) error { diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 7f008d8d7..bf319262f 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -452,19 +452,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(RoomserverQueryPublicUsersPath, - httputil.MakeInternalAPI("queryPublicUsers", func(req *http.Request) util.JSONResponse { - request := api.QueryPublicUsersRequest{} - response := api.QueryPublicUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryPublicUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { request := api.QueryServerBannedFromRoomRequest{} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index d3e16485e..d2b0e75c9 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -156,8 +156,6 @@ type Database interface { JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // GetPublicUsers searches all users that are in a public room. - GetPublicUsers(ctx context.Context, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index c51bd5674..c7d48d6e9 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -124,13 +124,6 @@ var selectKnownUsersSQL = "" + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" -// Select users in public rooms -const selectPublicUsersSQL = "" + - "SELECT DISTINCT event_state_key FROM roomserver_event_state_keys INNER JOIN roomserver_membership ON " + - "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid INNER JOIN roomserver_rooms ON " + - "roomserver_rooms.room_nid = roomserver_membership.room_nid INNER JOIN roomserver_published ON roomserver_published.room_id = roomserver_rooms.room_id " + - "AND event_state_key LIKE $1 LIMIT $2" - type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -143,7 +136,6 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt - selectPublicUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt } @@ -167,7 +159,6 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, - {&s.selectPublicUsersStmt, selectPublicUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } @@ -257,7 +248,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, forgotten bool, + eventNID types.EventNID, forgotten bool, displayname *string, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten, @@ -323,23 +314,6 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type return result, rows.Err() } -func (s *membershipStatements) SelectPublicUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - rows, err := s.selectPublicUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectPublicUsers: rows.close() failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} - func (s *membershipStatements) UpdateForgetMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 57f3a520a..6aace6f78 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false, nil); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -111,7 +111,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er } // SetToJoin implements types.MembershipUpdater -func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { +func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, displayname *string, isUpdate bool) ([]string, error) { var inviteEventIDs []string err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { @@ -137,7 +137,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false, displayname); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -171,7 +171,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false, nil); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 5092ca8a2..7d9760290 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1063,16 +1063,17 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) if err != nil { - return nil, err + if err == sql.ErrNoRows { + // Happens if a search is performed before user joins any rooms + // NID is never -1, so SelectKnownUsers will only pull users in public rooms + stateKeyNID = -1 + } else { + return nil, err + } } return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) } -// Get users in public rooms -func (d *Database) GetPublicUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - return d.MembershipTable.SelectPublicUsers(ctx, searchString, limit) -} - // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDs(ctx) diff --git a/roomserver/storage/sqlite3/deltas/20210622210900_add_displayname_column.go b/roomserver/storage/sqlite3/deltas/20210622210900_add_displayname_column.go new file mode 100644 index 000000000..cd7863560 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20210622210900_add_displayname_column.go @@ -0,0 +1,79 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddDisplaynameColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddDisplaynameColumn, DownAddDisplaynameColumn) +} + +func UpAddDisplaynameColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + displayname TEXT, + UNIQUE (room_nid, target_nid) +); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local, forgotten + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local, forgotten + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddDisplaynameColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) +); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local, forgotten + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local, forgotten + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index d9e8b76db..c87f7c5e8 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -37,6 +37,7 @@ const membershipSchema = ` event_nid INTEGER NOT NULL DEFAULT 0, target_local BOOLEAN NOT NULL DEFAULT false, forgotten BOOLEAN NOT NULL DEFAULT false, + displayname TEXT, UNIQUE (room_nid, target_nid) ); ` @@ -80,8 +81,8 @@ const selectMembershipForUpdateSQL = "" + " WHERE room_nid = $1 AND target_nid = $2" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + - " WHERE room_nid = $5 AND target_nid = $6" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4, displayname = $5" + + " WHERE room_nid = $6 AND target_nid = $7" const updateMembershipForgetRoom = "" + "UPDATE roomserver_membership SET forgotten = $1" + @@ -94,18 +95,16 @@ const selectRoomsWithMembershipSQL = "" + // joined to. Since this information is used to populate the user directory, we will // only return users that the user would ordinarily be able to see anyway. var selectKnownUsersSQL = "" + - "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + - "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + - " WHERE room_nid IN (" + - " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + - ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" - -// Select users in public rooms -const selectPublicUsersSQL = "" + - "SELECT DISTINCT event_state_key FROM roomserver_event_state_keys INNER JOIN roomserver_membership ON " + - "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid INNER JOIN roomserver_rooms ON " + - "roomserver_rooms.room_nid = roomserver_membership.room_nid INNER JOIN roomserver_published ON roomserver_published.room_id = roomserver_rooms.room_id " + - "AND event_state_key LIKE $1 LIMIT $2" + "SELECT DISTINCT event_state_key, displayname FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid " + + "WHERE room_nid IN (" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid = $1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + " UNION " + + " SELECT DISTINCT room_nid from roomserver_rooms INNER JOIN roomserver_published ON " + + " roomserver_rooms.room_id = roomserver_published.room_id WHERE published = true " + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND ( " + + " event_state_key LIKE $2 OR displayname LIKE $2" + + ") LIMIT $3" type membershipStatements struct { db *sql.DB @@ -119,7 +118,6 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt - selectPublicUsersStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt } @@ -144,7 +142,6 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, - {&s.selectPublicUsersStmt, selectPublicUsersSQL}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } @@ -235,11 +232,11 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, forgotten bool, + eventNID types.EventNID, forgotten bool, displayname *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, + ctx, senderUserNID, membership, eventNID, forgotten, displayname, roomNID, targetUserNID, ) return err } @@ -295,24 +292,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") for rows.Next() { var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} - -func (s *membershipStatements) SelectPublicUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - rows, err := s.selectPublicUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectPublicUsers: rows.close() failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { + var displayname string + if err := rows.Scan(&userID, &displayname); err != nil { return nil, err } result = append(result, userID) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index c07ab507a..5c32954c1 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" @@ -63,6 +64,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) deltas.LoadStateBlocksRefactor(m) + deltas.LoadAddDisplaynameColumn(m) if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } @@ -70,7 +72,7 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // Then prepare the statements. Now that the migrations have run, any columns referred // to in the database code should now exist. if err := d.prepare(db, cache); err != nil { - return nil, err + return nil, fmt.Errorf("Error preparing statements: %v", err) } return &d, nil diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 1dcf16f2e..51b2a51da 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -128,13 +128,12 @@ type Membership interface { SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) - UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool, displayname *string) error SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) - SelectPublicUsers(ctx context.Context, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error }