diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go index 42922743b..9652cdac7 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go @@ -44,9 +44,6 @@ const insertMembershipSQL = ` ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id ` -const selectMembershipSQL = "" + - "SELECT * from account_memberships WHERE localpart = $1 AND room_id = $2" - const selectMembershipsByLocalpartSQL = "" + "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index fb98946a0..76b9a4dd9 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -121,7 +121,8 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // SaveMembership saves the user matching a given localpart as a member of a given -// room. It also stores the ID of the `join` membership event. +// room. It also stores the ID of the membership event and a flag on whether the user +// is still in the room. // If a membership already exists between the user and the room, or of the // insert fails, returns the SQL error func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error { @@ -156,23 +157,19 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT }) } -// GetMembershipsByLocalpart returns an array containing the IDs of all the rooms -// a user matching a given localpart is a member of +// GetMembershipsByLocalpart returns an array containing the memberships for all +// the rooms a user matching a given localpart is a member of // If no membership match the given localpart, returns an empty array // If there was an issue during the retrieval, returns the SQL error func (d *Database) GetMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { return d.memberships.selectMembershipsByLocalpart(localpart) } -// UpdateMembership update the "join" membership event ID of a membership. -// This is useful in case of membership upgrade (e.g. profile update) -// If there was an issue during the update, returns the SQL error -func (d *Database) UpdateMembership(oldEventID string, newEventID string) error { - return d.memberships.updateMembershipByEventID(oldEventID, newEventID) -} - -// newMembership will save a new membership in the database if the given state -// event is a "join" membership event +// newMembership will save a new membership in the database, with a flag on whether +// the user is still in the room. This flag is set to true if the given state +// event is a "join" membership event and false if the event is a "leave" or "ban" +// membership. If the event isn't a m.room.member event with one of these three +// values, does nothing. // If the event isn't a "join" membership event, does nothing // If an error occurred, returns it func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error { diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/memberships.go b/src/github.com/matrix-org/dendrite/clientapi/readers/memberships.go new file mode 100644 index 000000000..c734e3b34 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/memberships.go @@ -0,0 +1,55 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package readers + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/util" +) + +// GetMemberships implements GET /rooms/{roomId}/members +func GetMemberships( + req *http.Request, device *authtypes.Device, roomID string, + accountDB *accounts.Database, cfg config.Dendrite, + queryAPI api.RoomserverQueryAPI, +) util.JSONResponse { + queryReq := api.QueryMembershipsForRoomRequest{ + RoomID: roomID, + Sender: device.UserID, + } + var queryRes api.QueryMembershipsForRoomResponse + if err := queryAPI.QueryMembershipsForRoom(&queryReq, &queryRes); err != nil { + return httputil.LogThenError(req, err) + } + + if !queryRes.HasBeenInRoom { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("You aren't a member of the room and weren't previously a member of the room."), + } + } + + return util.JSONResponse{ + Code: 200, + JSON: queryRes.JoinEvents, + } +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index eb84b2a9e..8a5799c0e 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -313,6 +313,13 @@ func Setup( }), ) + r0mux.Handle("/rooms/{roomID}/members", + common.MakeAuthAPI("rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars := mux.Vars(req) + return readers.GetMemberships(req, device, vars["roomID"], accountDB, cfg, queryAPI) + }), + ) + r0mux.Handle("/rooms/{roomID}/read_markers", common.MakeAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { // TODO: return the read_markers. diff --git a/src/github.com/matrix-org/dendrite/common/sql.go b/src/github.com/matrix-org/dendrite/common/sql.go index cabbe6662..c2fb753fc 100644 --- a/src/github.com/matrix-org/dendrite/common/sql.go +++ b/src/github.com/matrix-org/dendrite/common/sql.go @@ -18,6 +18,24 @@ import ( "database/sql" ) +// A Transaction is something that can be committed or rolledback. +type Transaction interface { + // Commit the transaction + Commit() error + // Rollback the transaction. + Rollback() error +} + +// EndTransaction ends a transaction. +// If the transaction succeeded then it is committed, otherwise it is rolledback. +func EndTransaction(txn Transaction, succeeded *bool) { + if *succeeded { + txn.Commit() + } else { + txn.Rollback() + } +} + // WithTransaction runs a block of code passing in an SQL transaction // If the code returns an error or panics then the transactions is rolledback // Otherwise the transaction is committed. @@ -26,16 +44,25 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { if err != nil { return } - defer func() { - if r := recover(); r != nil { - txn.Rollback() - panic(r) - } else if err != nil { - txn.Rollback() - } else { - err = txn.Commit() - } - }() + succeeded := false + defer EndTransaction(txn, &succeeded) + err = fn(txn) + if err != nil { + return + } + + succeeded = true return } + +// TxStmt wraps an SQL stmt inside an optional transaction. +// If the transaction is nil then it returns the original statement that will +// run outside of a transaction. +// Otherwise returns a copy of the statement that will run inside the transaction. +func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt { + if transaction != nil { + statement = transaction.Stmt(statement) + } + return statement +} diff --git a/src/github.com/matrix-org/dendrite/common/test/server.go b/src/github.com/matrix-org/dendrite/common/test/server.go index d990d4105..8e1af2041 100644 --- a/src/github.com/matrix-org/dendrite/common/test/server.go +++ b/src/github.com/matrix-org/dendrite/common/test/server.go @@ -94,6 +94,8 @@ func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) { "--sync-api-server-url", "http://" + string(cfg.Listen.SyncAPI), "--client-api-server-url", "http://" + string(cfg.Listen.ClientAPI), "--media-api-server-url", "http://" + string(cfg.Listen.MediaAPI), + "--tls-cert", "server.crt", + "--tls-key", "server.key", } return CreateBackgroundCommand( filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"), diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go index 7ba1b0b07..fffcc7f3f 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -79,18 +80,18 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { func (s *joinedHostsStatements) insertJoinedHosts( txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - _, err := txn.Stmt(s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) + _, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName) return err } func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error { - _, err := txn.Stmt(s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) + _, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs)) return err } func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - rows, err := txn.Stmt(s.selectJoinedHostsStmt).Query(roomID) + rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go index daac7ddf4..bcc0bb1df 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go @@ -16,6 +16,8 @@ package storage import ( "database/sql" + + "github.com/matrix-org/dendrite/common" ) const roomSchema = ` @@ -65,7 +67,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { - _, err := txn.Stmt(s.insertRoomStmt).Exec(roomID) + _, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID) return err } @@ -74,7 +76,7 @@ func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error { // exists by calling insertRoom first. func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) { var lastEventID string - err := txn.Stmt(s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) + err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID) if err != nil { return "", err } @@ -84,6 +86,6 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error { - _, err := txn.Stmt(s.updateRoomStmt).Exec(roomID, lastEventID) + _, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/api/input.go b/src/github.com/matrix-org/dendrite/roomserver/api/input.go index 558eb28c4..cbe7399ba 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/api/input.go +++ b/src/github.com/matrix-org/dendrite/roomserver/api/input.go @@ -68,9 +68,17 @@ type InputRoomEvent struct { SendAsServer string `json:"send_as_server"` } +// InputInviteEvent is a matrix invite event received over federation without +// the usual context a matrix room event would have. We usually do not have +// access to the events needed to check the event auth rules for the invite. +type InputInviteEvent struct { + Event gomatrixserverlib.Event `json:"event"` +} + // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + InputInviteEvents []InputInviteEvent `json:"input_invite_events"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/src/github.com/matrix-org/dendrite/roomserver/api/query.go b/src/github.com/matrix-org/dendrite/roomserver/api/query.go index 6e6a838a9..f07da59e5 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/api/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/api/query.go @@ -100,6 +100,23 @@ type QueryEventsByIDResponse struct { Events []gomatrixserverlib.Event `json:"events"` } +// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom +type QueryMembershipsForRoomRequest struct { + // ID of the room to fetch memberships from + RoomID string `json:"room_id"` + // ID of the user sending the request + Sender string `json:"sender"` +} + +// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom +type QueryMembershipsForRoomResponse struct { + // The "m.room.member" events (of "join" membership) in the client format + JoinEvents []gomatrixserverlib.ClientEvent `json:"join_events"` + // True if the user has been in room before and has either stayed in it or + // left it. + HasBeenInRoom bool `json:"has_been_in_room"` +} + // RoomserverQueryAPI is used to query information from the room server. type RoomserverQueryAPI interface { // Query the latest events and state for a room from the room server. @@ -119,6 +136,12 @@ type RoomserverQueryAPI interface { request *QueryEventsByIDRequest, response *QueryEventsByIDResponse, ) error + + // Query a list of membership events for a room + QueryMembershipsForRoom( + request *QueryMembershipsForRoomRequest, + response *QueryMembershipsForRoomResponse, + ) error } // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. @@ -130,6 +153,9 @@ const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEven // RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID" +// RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API +const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom" + // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil then it uses the http.DefaultClient func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { @@ -171,6 +197,15 @@ func (h *httpRoomserverQueryAPI) QueryEventsByID( return postJSON(h.httpClient, apiURL, request, response) } +// QueryMembershipsForRoom implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryMembershipsForRoom( + request *QueryMembershipsForRoomRequest, + response *QueryMembershipsForRoomResponse, +) error { + apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath + return postJSON(h.httpClient, apiURL, request, response) +} + func postJSON(httpClient *http.Client, apiURL string, request, response interface{}) error { jsonBytes, err := json.Marshal(request) if err != nil { diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index c1eee4c96..82b4652e6 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -15,6 +15,9 @@ package input import ( + "fmt" + + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -39,6 +42,8 @@ type RoomEventDatabase interface { GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error) // Lookup the string event IDs for a list of numeric event IDs EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // Build a membership updater for the target user in a room. + MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error) } // OutputRoomEventWriter has the APIs needed to write an event to the output logs. @@ -103,13 +108,64 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api. return err } - // TODO: - // * Caculate the new current state for the room if the forward extremities have changed. - // * Work out the delta between the new current state and the previous current state. - // * Work out the visibility of the event. - // * Write a message to the output logs containing: - // - The event itself - // - The visiblity of the event, i.e. who is allowed to see the event. - // - The changes to the current state of the room. + return nil +} + +func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) { + if input.Event.StateKey() == nil { + return fmt.Errorf("invite must be a state event") + } + + roomID := input.Event.RoomID() + targetUserID := *input.Event.StateKey() + + updater, err := db.MembershipUpdater(roomID, targetUserID) + if err != nil { + return err + } + succeeded := false + defer common.EndTransaction(updater, &succeeded) + + if updater.IsJoin() { + // If the user is joined to the room then that takes precedence over this + // invite event. It makes little sense to move a user that is already + // joined to the room into the invite state. + // This could plausibly happen if an invite request raced with a join + // request for a user. For example if a user was invited to a public + // room and they joined the room at the same time as the invite was sent. + // The other way this could plausibly happen is if an invite raced with + // a kick. For example if a user was kicked from a room in error and in + // response someone else in the room re-invited them then it is possible + // for the invite request to race with the leave event so that the + // target receives invite before it learns that it has been kicked. + // There are a few ways this could be plausibly handled in the roomserver. + // 1) Store the invite, but mark it as retired. That will result in the + // permanent rejection of that invite event. So even if the target + // user leaves the room and the invite is retransmitted it will be + // ignored. However a new invite with a new event ID would still be + // accepted. + // 2) Silently discard the invite event. This means that if the event + // was retransmitted at a later date after the target user had left + // the room we would accept the invite. However since we hadn't told + // the sending server that the invite had been discarded it would + // have no reason to attempt to retry. + // 3) Signal the sending server that the user is already joined to the + // room. + // For now we will implement option 2. Since in the abesence of a retry + // mechanism it will be equivalent to option 1, and we don't have a + // signalling mechanism to implement option 3. + return nil + } + + outputUpdates, err := updateToInviteMembership(updater, &input.Event, nil) + if err != nil { + return err + } + + if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { + return err + } + + succeeded = true return nil } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/input.go b/src/github.com/matrix-org/dendrite/roomserver/input/input.go index 210abfa29..17e94599e 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/input.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/input.go @@ -61,6 +61,11 @@ func (r *RoomserverInputAPI) InputRoomEvents( return err } } + for i := range request.InputInviteEvents { + if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil { + return err + } + } return nil } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go index 9328ecf3b..d9aa2b455 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go @@ -17,6 +17,7 @@ package input import ( "bytes" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" @@ -52,25 +53,19 @@ func updateLatestEvents( if err != nil { return } - defer func() { - if err == nil { - // Commit if there wasn't an error. - // Set the returned err value if we encounter an error committing. - // This only works because err is a named return. - err = updater.Commit() - } else { - // Ignore any error we get rolling back since we don't want to - // clobber the current error - // TODO: log the error here. - updater.Rollback() - } - }() + succeeded := false + defer common.EndTransaction(updater, &succeeded) u := latestEventsUpdater{ db: db, updater: updater, ow: ow, roomNID: roomNID, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, } - return u.doUpdateLatestEvents() + if err = u.doUpdateLatestEvents(); err != nil { + return err + } + + succeeded = true + return } // latestEventsUpdater tracks the state used to update the latest events in the diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go index f306697ff..6eeb0914d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/membership.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/membership.go @@ -95,10 +95,9 @@ func updateMembership( return nil, err } } - if old == new { + if old == new && new != "join" { // If the membership is the same then nothing changed and we can return - // immediately. This should help speed up processing for display name - // changes where the membership is "join" both before and after. + // immediately, unless it's a "join" update (e.g. profile update). return updates, nil } @@ -152,16 +151,21 @@ func updateToInviteMembership( func updateToJoinMembership( mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { - // If the user is already marked as being joined then we can return immediately. - // TODO: Is this code reachable given the "old != new" guard in updateMembership? + // If the user is already marked as being joined, we call SetToJoin to update + // the event ID then we can return immediately. Retired is ignored as there + // is no invite event to retire. if mu.IsJoin() { + _, err := mu.SetToJoin(add.Sender(), add.EventID(), true) + if err != nil { + return nil, err + } return updates, nil } // When we mark a user as being joined we will invalidate any invites that // are active for that user. We notify the consumers that the invites have // been retired using a special event, even though they could infer this // by studying the state changes in the room event stream. - retired, err := mu.SetToJoin(add.Sender()) + retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) if err != nil { return nil, err } @@ -194,7 +198,7 @@ func updateToLeaveMembership( // are active for that user. We notify the consumers that the invites have // been retired using a special event, even though they could infer this // by studying the state changes in the room event stream. - retired, err := mu.SetToLeave(add.Sender()) + retired, err := mu.SetToLeave(add.Sender(), add.EventID()) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/query/query.go b/src/github.com/matrix-org/dendrite/roomserver/query/query.go index 30b695fb4..84f5d44c3 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/query/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/query/query.go @@ -52,6 +52,13 @@ type RoomserverQueryAPIDatabase interface { // Remove a given room alias. // Returns an error if there was a problem talking to the database. RemoveRoomAlias(alias string) error + // Lookup the join events for all members in a room as requested by a given + // user. If the user is currently in the room, returns the room's current + // members, if not returns an empty array (TODO: Fix it) + // If the user requesting the list of members has never been in the room, + // returns nil. + // If there was an issue retrieving the events, returns an error. + GetMembershipEvents(roomNID types.RoomNID, requestSenderUserID string) (events []types.Event, err error) } // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI @@ -182,6 +189,37 @@ func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixs return result, nil } +// QueryMembershipsForRoom implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryMembershipsForRoom( + request *api.QueryMembershipsForRoomRequest, + response *api.QueryMembershipsForRoomResponse, +) error { + roomNID, err := r.DB.RoomNID(request.RoomID) + if err != nil { + return err + } + + events, err := r.DB.GetMembershipEvents(roomNID, request.Sender) + if err != nil { + return nil + } + + if events == nil { + response.HasBeenInRoom = false + response.JoinEvents = nil + return nil + } + + response.HasBeenInRoom = true + response.JoinEvents = []gomatrixserverlib.ClientEvent{} + for _, event := range events { + clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + response.JoinEvents = append(response.JoinEvents, clientEvent) + } + + return nil +} + // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { servMux.Handle( @@ -226,4 +264,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: 200, JSON: &response} }), ) + servMux.Handle( + api.RoomserverQueryMembershipsForRoomPath, + common.MakeAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryMembershipsForRoomRequest + var response api.QueryMembershipsForRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMembershipsForRoom(&request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: 200, JSON: &response} + }), + ) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go index d30e45815..b06f5b2a5 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/event_state_keys_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,10 +59,22 @@ const bulkSelectEventStateKeyNIDSQL = "" + "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" + " WHERE event_state_key = ANY($1)" +const selectEventStateKeySQL = "" + + "SELECT event_state_key FROM roomserver_event_state_keys" + + " WHERE event_state_key_nid = $1" + +// Bulk lookup from numeric ID to string state key for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeySQL = "" + + "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" + + " WHERE event_state_key_nid = ANY($1)" + type eventStateKeyStatements struct { insertEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt } func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { @@ -72,27 +85,21 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, + {&s.selectEventStateKeyStmt, selectEventStateKeySQL}, {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, }.prepare(db) } func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := s.insertEventStateKeyNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { var eventStateKeyNID int64 - stmt := s.selectEventStateKeyNIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - err := stmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) return types.EventStateKeyNID(eventStateKeyNID), err } @@ -114,3 +121,32 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st } return result, nil } + +func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) { + var eventStateKey string + err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey) + return eventStateKey, err +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) { + var nIDs pq.Int64Array + for i := range eventStateKeyNIDs { + nIDs[i] = int64(eventStateKeyNIDs[i]) + } + rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[types.EventStateKeyNID(stateKeyNID)] = stateKey + } + return result, nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index b6db15c82..2d2b85625 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -253,22 +254,22 @@ func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID typ } func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) { - err = txn.Stmt(s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) + err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) return } func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error { - _, err := txn.Stmt(s.updateEventSentToOutputStmt).Exec(int64(eventNID)) + _, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID)) return err } func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) { - err = txn.Stmt(s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) + err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) return } func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { - rows, err := txn.Stmt(s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) + rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go index 9e0860b42..8bae2b781 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -94,7 +95,7 @@ func (s *inviteStatements) insertInviteEvent( targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { - result, err := txn.Stmt(s.insertInviteEventStmt).Exec( + result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec( inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ) if err != nil { @@ -110,7 +111,7 @@ func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) updateInviteRetired( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { - rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) + rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go index 725e5b8d9..6edc7a528 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -33,7 +34,7 @@ const membershipSchema = ` -- and the room state tables. -- This table is updated in one of 3 ways: -- 1) The membership of a user changes within the current state of the room. --- 2) An invite is received outside of a room over federation. +-- 2) An invite is received outside of a room over federation. -- 3) An invite is rejected outside of a room over federation. CREATE TABLE IF NOT EXISTS roomserver_membership ( room_nid BIGINT NOT NULL, @@ -46,6 +47,16 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- The state the user is in within this room. -- Default value is "membershipStateLeaveOrBan" membership_nid BIGINT NOT NULL DEFAULT 1, + -- The numeric ID of the membership event. + -- It refers to the join membership event if the membership_nid is join (3), + -- and to the leave/ban membership event if the membership_nid is leave or + -- ban (1). + -- If the membership_nid is invite (2) and the user has been in the room + -- before, it will refer to the previous leave/ban membership event, and will + -- be equals to 0 (its default) if the user never joined the room before. + -- This NID is updated if the join event gets updated (e.g. profile update), + -- or if the user leaves/joins the room. + event_nid BIGINT NOT NULL DEFAULT 0, UNIQUE (room_nid, target_nid) ); ` @@ -57,18 +68,33 @@ const insertMembershipSQL = "" + " VALUES ($1, $2)" + " ON CONFLICT DO NOTHING" +const selectMembershipFromRoomAndTargetSQL = "" + + "SELECT membership_nid, event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const selectMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + +const selectMembershipsFromRoomSQL = "" + + "SELECT membership_nid, event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4" + + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + " WHERE room_nid = $1 AND target_nid = $2" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +106,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, + {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, + {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -87,25 +116,72 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) error { - _, err := txn.Stmt(s.insertMembershipStmt).Exec(roomNID, targetUserNID) + _, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID) return err } func (s *membershipStatements) selectMembershipForUpdate( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership membershipState, err error) { - err = txn.Stmt(s.selectMembershipForUpdateStmt).QueryRow( + err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow( roomNID, targetUserNID, ).Scan(&membership) return } +func (s *membershipStatements) selectMembershipFromRoomAndTarget( + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventNID types.EventNID, membership membershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRow( + roomNID, targetUserNID, + ).Scan(&membership, &eventNID) + return +} + +func (s *membershipStatements) selectMembershipsFromRoom( + roomNID types.RoomNID, +) (eventNIDs map[types.EventNID]membershipState, err error) { + rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID) + if err != nil { + return + } + + eventNIDs = make(map[types.EventNID]membershipState) + for rows.Next() { + var eNID types.EventNID + var membership membershipState + if err = rows.Scan(&membership, &eNID); err != nil { + return + } + eventNIDs[eNID] = membership + } + return +} +func (s *membershipStatements) selectMembershipsFromRoomAndMembership( + roomNID types.RoomNID, membership membershipState, +) (eventNIDs []types.EventNID, err error) { + rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership) + if err != nil { + return + } + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + func (s *membershipStatements) updateMembership( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership membershipState, + eventNID types.EventNID, ) error { - _, err := txn.Stmt(s.updateMembershipStmt).Exec( - roomNID, targetUserNID, senderUserNID, membership, + _, err := common.TxStmt(txn, s.updateMembershipStmt).Exec( + roomNID, targetUserNID, senderUserNID, membership, eventNID, ) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go index 71795d488..9fcf1cb5c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/previous_events_table.go @@ -17,6 +17,7 @@ package storage import ( "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -73,7 +74,7 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) { } func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error { - _, err := txn.Stmt(s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) + _, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) return err } @@ -81,5 +82,5 @@ func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEvent // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error { var ok int64 - return txn.Stmt(s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) + return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index 03cacd7db..4ba329f39 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -80,15 +81,15 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *roomStatements) insertRoomNID(roomID string) (types.RoomNID, error) { +func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { var roomNID int64 - err := s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } -func (s *roomStatements) selectRoomNID(roomID string) (types.RoomNID, error) { +func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { var roomNID int64 - err := s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) return types.RoomNID(roomNID), err } @@ -112,7 +113,7 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty var nids pq.Int64Array var lastEventSentNID int64 var stateSnapshotNID int64 - err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { return nil, 0, 0, err } @@ -127,7 +128,7 @@ func (s *roomStatements) updateLatestEventNIDs( txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec( + _, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec( roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), ) return err diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index d323fd139..fbbc723ee 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -53,7 +53,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ err error ) - if roomNID, err = d.assignRoomNID(event.RoomID()); err != nil { + if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil { return 0, types.StateAtEvent{}, err } @@ -104,15 +104,15 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ }, nil } -func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) { +func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(roomID) + roomNID, err := d.statements.selectRoomNID(txn, roomID) if err == sql.ErrNoRows { // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(roomID) + roomNID, err = d.statements.insertRoomNID(txn, roomID) if err == sql.ErrNoRows { // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(roomID) + roomNID, err = d.statements.selectRoomNID(txn, roomID) } } return roomNID, err @@ -329,7 +329,7 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta // RoomNID implements query.RoomserverQueryAPIDB func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(roomID) + roomNID, err := d.statements.selectRoomNID(nil, roomID) if err == sql.ErrNoRows { return 0, nil } @@ -380,6 +380,38 @@ func (d *Database) StateEntriesForTuples( return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples) } +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() + } + }() + + roomNID, err := d.assignRoomNID(txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + type membershipUpdater struct { transaction d *Database @@ -435,7 +467,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er } if u.membership != membershipStateInvite { if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, + u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, ); err != nil { return false, err } @@ -444,7 +476,43 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er } // SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) { +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Lookup the NID of the new join event + nIDs, err := u.d.EventNIDs([]string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) if err != nil { return nil, err @@ -455,9 +523,16 @@ func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) { if err != nil { return nil, err } - if u.membership != membershipStateJoin { + + // Lookup the NID of the new leave event + nIDs, err := u.d.EventNIDs([]string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateLeaveOrBan { if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, + u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return nil, err } @@ -465,26 +540,49 @@ func (u *membershipUpdater) SetToJoin(senderUserID string) ([]string, error) { return inviteEventIDs, nil } -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) +// GetMembershipEvents implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEvents(roomNID types.RoomNID, requestSenderUserID string) (events []types.Event, err error) { + txn, err := d.db.Begin() if err != nil { - return nil, err + return } - inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.txn, u.roomNID, u.targetUserNID, - ) + defer txn.Commit() + + requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID) if err != nil { - return nil, err + return } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, - ); err != nil { + + _, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return nil, nil + } else if err != nil { + return + } + + if senderMembership == membershipStateJoin { + // The user is still in the room: Send the current list of joined members + var joinEventNIDs []types.EventNID + joinEventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin) + if err != nil { return nil, err } + + events, err = d.Events(joinEventNIDs) + } else { + // The user isn't in the room anymore + // TODO: Send the list of joined member as it was when the user left + // We cannot do this using only the memberships database, as it + // only stores the latest join event NID for a given target user. + // The solution would be to build the state of a room after before + // the leave event and extract a members list from it. + // For now, we return an empty slice so we know the user has been + // in the room before. + events = []types.Event{} } - return inviteEventIDs, nil + + return } type transaction struct { diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 809b6e574..d5fe32762 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -16,6 +16,7 @@ package types import ( + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -135,14 +136,6 @@ type StateEntryList struct { StateEntries []StateEntry } -// A Transaction is something that can be committed or rolledback. -type Transaction interface { - // Commit the transaction - Commit() error - // Rollback the transaction. - Rollback() error -} - // A RoomRecentEventsUpdater is used to update the recent events in a room. // (On postgresql this wraps a database transaction that holds a "FOR UPDATE" // lock on the row in the rooms table holding the latest events for the room.) @@ -175,7 +168,7 @@ type RoomRecentEventsUpdater interface { // It will share the same transaction as this updater. MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) // Implements Transaction so it can be committed or rolledback - Transaction + common.Transaction } // A MembershipUpdater is used to update the membership of a user in a room. @@ -193,14 +186,14 @@ type MembershipUpdater interface { // Set the state to invite. // Returns whether this invite needs to be sent SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error) - // Set the state to join. + // Set the state to join or updates the event ID in the database. // Returns a list of invite event IDs that this state change retired. - SetToJoin(senderUserID string) (inviteEventIDs []string, err error) + SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) // Set the state to leave. // Returns a list of invite event IDs that this state change retired. - SetToLeave(senderUserID string) (inviteEventIDs []string, err error) + SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) // Implements Transaction so it can be committed or rolledback. - Transaction + common.Transaction } // A MissingEventError is an error that happened because the roomserver was diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go index 9958e0d15..10933e965 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,7 +137,7 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { - rows, err := txn.Stmt(s.selectRoomIDsWithMembershipStmt).Query(userID, membership) + rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { - rows, err := txn.Stmt(s.selectCurrentStateStmt).Query(roomID) + rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID) if err != nil { return nil, err } @@ -165,21 +166,21 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri } func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error { - _, err := txn.Stmt(s.deleteRoomStateByEventIDStmt).Exec(eventID) + _, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID) return err } func (s *currentRoomStateStatements) upsertRoomState( txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64, ) error { - _, err := txn.Stmt(s.upsertRoomStateStmt).Exec( + _, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec( event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt, ) return err } func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - rows, err := txn.Stmt(s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) + rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go index f3c46298a..93774d1f1 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go @@ -19,6 +19,7 @@ import ( log "github.com/Sirupsen/logrus" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -105,7 +106,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) selectStateInRange( txn *sql.Tx, oldPos, newPos types.StreamPosition, ) (map[string]map[string]bool, map[string]streamEvent, error) { - rows, err := txn.Stmt(s.selectStateInRangeStmt).Query(oldPos, newPos) + rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos) if err != nil { return nil, nil, err } @@ -167,12 +168,8 @@ func (s *outputRoomEventsStatements) selectStateInRange( // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) { - stmt := s.selectMaxIDStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } var nullableID sql.NullInt64 - err = stmt.QueryRow().Scan(&nullableID) + err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } @@ -182,7 +179,7 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { - err = txn.Stmt(s.insertEventStmt).QueryRow( + err = common.TxStmt(txn, s.insertEventStmt).QueryRow( event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), ).Scan(&streamPos) return @@ -209,11 +206,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // from the database. func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - stmt := s.selectEventsStmt - if txn != nil { - stmt = txn.Stmt(stmt) - } - rows, err := stmt.Query(pq.StringArray(eventIDs)) + rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs)) if err != nil { return nil, err }