diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go new file mode 100644 index 000000000..ad5312db6 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go @@ -0,0 +1,23 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authtypes + +// Membership represents the relationship between a user and a room they're a +// member of +type Membership struct { + Localpart string + RoomID string + EventID string +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go index 8eca4a574..a6d82860f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) const membershipSchema = ` @@ -45,16 +46,20 @@ const selectMembershipSQL = "" + "SELECT * from memberships WHERE localpart = $1 AND room_id = $2" const selectMembershipsByLocalpartSQL = "" + - "SELECT room_id FROM memberships WHERE localpart = $1" + "SELECT room_id, event_id FROM memberships WHERE localpart = $1" const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM memberships WHERE event_id = ANY($1)" +const updateMembershipByEventIDSQL = "" + + "UPDATE memberships SET event_id = $2 WHERE event_id = $1" + type membershipStatements struct { deleteMembershipsByEventIDsStmt *sql.Stmt insertMembershipStmt *sql.Stmt selectMembershipByEventIDStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt + updateMembershipByEventIDStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -71,6 +76,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } + if s.updateMembershipByEventIDStmt, err = db.Prepare(updateMembershipByEventIDSQL); err != nil { + return + } return } @@ -83,3 +91,29 @@ func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, tx _, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs)) return } + +func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { + rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart) + if err != nil { + return + } + + memberships = []authtypes.Membership{} + + defer rows.Close() + for rows.Next() { + var m authtypes.Membership + m.Localpart = localpart + if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { + return nil, err + } + memberships = append(memberships, m) + } + + return +} + +func (s *membershipStatements) updateMembershipByEventID(oldEventID string, newEventID string) (err error) { + _, err = s.updateMembershipByEventIDStmt.Exec(oldEventID, newEventID) + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 65d87d5a7..fcada6d81 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -151,6 +151,21 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT }) } +// GetMembershipsByLocalpart returns an array containing the IDs of all the rooms +// a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(localpart) +} + +// UpdateMembership update the "join" membership event ID of a membership. +// This is useful in case of membership upgrade (e.g. profile update) +// If there was an issue during the update, returns the SQL error +func (d *Database) UpdateMembership(oldEventID string, newEventID string) error { + return d.memberships.updateMembershipByEventID(oldEventID, newEventID) +} + // newMembership will save a new membership in the database if the given state // event is a "join" membership event // If the event isn't a "join" membership event, does nothing diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go index 6a3b4b377..939716ad8 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/profile.go @@ -15,12 +15,18 @@ package readers import ( + "fmt" "net/http" + "time" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/events" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -39,6 +45,12 @@ type displayName struct { DisplayName string `json:"displayname"` } +type prevMembership struct { + PrevContent events.MemberContent `json:"prev_content"` + PrevID string `json:"replaces_state"` + UserID string `json:"prev_sender"` +} + // GetProfile implements GET /profile/{userID} func GetProfile( req *http.Request, accountDB *accounts.Database, userID string, @@ -93,8 +105,11 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( req *http.Request, accountDB *accounts.Database, userID string, - producer *producers.UserUpdateProducer, + producer *producers.UserUpdateProducer, cfg *config.Dendrite, + rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, ) util.JSONResponse { + changedKey := "avatar_url" + var r avatarURL if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr @@ -116,11 +131,25 @@ func SetAvatarURL( return httputil.LogThenError(req, err) } - if err := accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil { + if err = accountDB.SetAvatarURL(localpart, r.AvatarURL); err != nil { return httputil.LogThenError(req, err) } - if err := producer.SendUpdate(userID, "avatar_url", oldProfile.AvatarURL, r.AvatarURL); err != nil { + memberships, err := accountDB.GetMembershipsByLocalpart(localpart) + if err != nil { + return httputil.LogThenError(req, err) + } + + events, err := buildMembershipEvents(memberships, accountDB, oldProfile, changedKey, r.AvatarURL, userID, cfg, queryAPI) + if err != nil { + return httputil.LogThenError(req, err) + } + + if err := rsProducer.SendEvents(events, cfg.Matrix.ServerName); err != nil { + return httputil.LogThenError(req, err) + } + + if err := producer.SendUpdate(userID, changedKey, oldProfile.AvatarURL, r.AvatarURL); err != nil { return httputil.LogThenError(req, err) } @@ -155,8 +184,11 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( req *http.Request, accountDB *accounts.Database, userID string, - producer *producers.UserUpdateProducer, + producer *producers.UserUpdateProducer, cfg *config.Dendrite, + rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, ) util.JSONResponse { + changedKey := "displayname" + var r displayName if resErr := httputil.UnmarshalJSONRequest(req, &r); resErr != nil { return *resErr @@ -178,11 +210,25 @@ func SetDisplayName( return httputil.LogThenError(req, err) } - if err := accountDB.SetDisplayName(localpart, r.DisplayName); err != nil { + if err = accountDB.SetDisplayName(localpart, r.DisplayName); err != nil { return httputil.LogThenError(req, err) } - if err := producer.SendUpdate(userID, "displayname", oldProfile.DisplayName, r.DisplayName); err != nil { + memberships, err := accountDB.GetMembershipsByLocalpart(localpart) + if err != nil { + return httputil.LogThenError(req, err) + } + + events, err := buildMembershipEvents(memberships, accountDB, oldProfile, changedKey, r.DisplayName, userID, cfg, queryAPI) + if err != nil { + return httputil.LogThenError(req, err) + } + + if err := rsProducer.SendEvents(events, cfg.Matrix.ServerName); err != nil { + return httputil.LogThenError(req, err) + } + + if err := producer.SendUpdate(userID, changedKey, oldProfile.DisplayName, r.DisplayName); err != nil { return httputil.LogThenError(req, err) } @@ -191,3 +237,97 @@ func SetDisplayName( JSON: struct{}{}, } } + +func buildMembershipEvents( + memberships []authtypes.Membership, db *accounts.Database, + oldProfile *authtypes.Profile, changedKey string, newValue string, + userID string, cfg *config.Dendrite, queryAPI api.RoomserverQueryAPI, +) ([]gomatrixserverlib.Event, error) { + ev := []gomatrixserverlib.Event{} + + for _, membership := range memberships { + prevContent := events.MemberContent{ + Membership: "join", + DisplayName: oldProfile.DisplayName, + AvatarURL: oldProfile.AvatarURL, + } + + prev := prevMembership{ + UserID: userID, + PrevID: membership.EventID, + PrevContent: prevContent, + } + + builder := gomatrixserverlib.EventBuilder{ + Sender: userID, + RoomID: membership.RoomID, + Type: "m.room.member", + StateKey: &userID, + } + + if err := builder.SetUnsigned(prev); err != nil { + return nil, err + } + + content := events.MemberContent{ + Membership: "join", + } + + if changedKey == "displayname" { + content.DisplayName = newValue + content.AvatarURL = oldProfile.AvatarURL + } else if changedKey == "avatar_url" { + content.DisplayName = oldProfile.DisplayName + content.AvatarURL = newValue + } + + if err := builder.SetContent(content); err != nil { + return nil, err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&builder) + if err != nil { + return nil, err + } + + // Ask the roomserver for information about this room + queryReq := api.QueryLatestEventsAndStateRequest{ + RoomID: membership.RoomID, + StateToFetch: eventsNeeded.Tuples(), + } + var queryRes api.QueryLatestEventsAndStateResponse + if queryErr := queryAPI.QueryLatestEventsAndState(&queryReq, &queryRes); queryErr != nil { + return nil, err + } + + authEvents := gomatrixserverlib.NewAuthEvents(nil) + + // Iterating the old way because range seems to mess things up. Might be + // worth investigating. + for i := 0; i < len(queryRes.StateEvents); i++ { + authEvents.AddEvent(&queryRes.StateEvents[i]) + } + + refs, err := eventsNeeded.AuthEventReferences(&authEvents) + if err != nil { + return nil, err + } + builder.AuthEvents = refs + + // Generate a new event ID and set it in the database + eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + if err = db.UpdateMembership(membership.EventID, eventID); err != nil { + return nil, err + } + + now := time.Now() + event, err := builder.Build(eventID, now, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey) + if err != nil { + return nil, err + } + + ev = append(ev, event) + } + + return ev, nil +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 8d6f024e2..048b95581 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -185,7 +185,7 @@ func Setup( r0mux.Handle("/profile/{userID}/avatar_url", common.MakeAuthAPI("profile_avatar_url", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars := mux.Vars(req) - return readers.SetAvatarURL(req, accountDB, vars["userID"], userUpdateProducer) + return readers.SetAvatarURL(req, accountDB, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) }), ).Methods("PUT", "OPTIONS") // Browsers use the OPTIONS HTTP method to check if the CORS policy allows @@ -201,7 +201,7 @@ func Setup( r0mux.Handle("/profile/{userID}/displayname", common.MakeAuthAPI("profile_displayname", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars := mux.Vars(req) - return readers.SetDisplayName(req, accountDB, vars["userID"], userUpdateProducer) + return readers.SetDisplayName(req, accountDB, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI) }), ).Methods("PUT", "OPTIONS") // Browsers use the OPTIONS HTTP method to check if the CORS policy allows