diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go index c41f9eb11..88c84febe 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/invite_table.go @@ -43,36 +43,18 @@ CREATE TABLE invites ( -- This is set implicitly when processing KIND_NEW events and explicitly -- when rejecting events over federation. retired BOOLEAN NOT NULL DEFAULT FALSE, - -- Whether the invite has been sent to the output stream. - -- We maintain a separate output stream of invite events since they don't - -- always occur within a room we have state for. - sent_invite_to_output BOOLEAN NOT NULL DEFAULT FALSE, - -- Whether the retirement has been sent to the output stream. - sent_retired_to_output BOOLEAN NOT NULL DEFAULT FALSE, -- The invite event JSON. invite_event_json TEXT NOT NULL ); CREATE INDEX invites_active_idx ON invites (target_state_key_nid, room_nid) WHERE NOT retired; - -CREATE INDEX invites_unsent_retired_idx ON invites (target_state_key_nid, room_nid) - WHERE retired AND NOT sent_retired_to_output; ` - const insertInviteEventSQL = "" + "INSERT INTO invites (invite_event_id, room_nid, target_state_key_nid," + " sender_state_key_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT DO NOTHING" -const selectInviteSQL = "" + - "SELECT retired, sent_invite_to_output FROM invites" + - " WHERE invite_event_id = $1" - -const updateInviteSentInviteToOutputSQL = "" + - "UPDATE invites SET sent_invite_to_output = TRUE" + - " WHERE invite_event_id = $1" - const selectInviteActiveForUserInRoomSQL = "" + "SELECT invite_event_id, sender_state_key_nid FROM invites" + " WHERE target_state_key_id = $1 AND room_nid = $2" + @@ -84,26 +66,14 @@ const selectInviteActiveForUserInRoomSQL = "" + // However the matrix protocol doesn't give us a way to reliably identify the // invites that were retired, so we are forced to retire all of them. const updateInviteRetiredSQL = "" + - "UPDATE invites SET retired_by_event_nid = TRUE" + - " WHERE room_nid = $1 AND target_state_key_nid = $2 AND NOT retired" - -const selectInviteUnsentRetiredSQL = "" + - "SELECT invite_event_id FROM invites" + - " WHERE target_state_key_id = $1 AND room_nid = $2" + - " AND retired AND NOT sent_retired_to_output" - -const updateInviteSentRetiredToOutputSQL = "" + - "UPDATE invites SET sent_retired_to_output = TRUE" + - " WHERE invite_event_id = $1" + "UPDATE invites SET retired = TRUE" + + " WHERE room_nid = $1 AND target_state_key_nid = $2 AND NOT retired" + + " RETURNING invite_event_id" type inviteStatements struct { insertInviteEventStmt *sql.Stmt - selectInviteStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt - selectInviteUnsentRetiredStmt *sql.Stmt - updateInviteSentInviteToOutputStmt *sql.Stmt - updateInviteSentRetiredToOutputStmt *sql.Stmt } func (s *inviteStatements) prepare(db *sql.DB) (err error) { @@ -114,40 +84,46 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, - {&s.selectInviteStmt, selectInviteSQL}, - {&s.updateInviteSentInviteToOutputStmt, updateInviteSentInviteToOutputSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, - - {&s.updateInviteSentRetiredToOutputStmt, updateInviteSentRetiredToOutputSQL}, }.prepare(db) } func (s *inviteStatements) insertInviteEvent( - txn *sql.Tx, inviteEventNID types.EventNID, roomNID types.RoomNID, + txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetNID, senderNID types.EventStateKeyNID, inviteEventJSON []byte, -) error { - _, err := txn.Stmt(s.insertInviteEventStmt).Exec( - inviteEventNID, roomNID, targetNID, senderNID, inviteEventJSON, +) (bool, error) { + result, err := txn.Stmt(s.insertInviteEventStmt).Exec( + inviteEventID, roomNID, targetNID, senderNID, inviteEventJSON, ) - return err + if err != nil { + return false, err + } + count, err := result.RowsAffected() + if err != nil { + return false, err + } + return count != 0, nil } func (s *inviteStatements) updateInviteRetired( txn *sql.Tx, roomNID types.RoomNID, targetNID types.EventStateKeyNID, -) error { - _, err := txn.Stmt(s.updateInviteRetiredStmt).Exec(roomNID, targetNID) - return err -} - -func (s *inviteStatements) selectInvite( - txn *sql.Tx, inviteEventNID types.EventNID, -) (RetiredByNID types.EventNID, sentInviteToOutput, sentRetiredToOutput bool, err error) { - err = txn.Stmt(s.selectInviteStmt).QueryRow(inviteEventNID).Scan( - &RetiredByNID, &sentInviteToOutput, &sentRetiredToOutput, - ) - return +) ([]string, error) { + rows, err := txn.Stmt(s.updateInviteRetiredStmt).Query(roomNID, targetNID) + if err != nil { + return nil, err + } + defer rows.Close() + var result []string + for rows.Next() { + var inviteEventID string + if err := rows.Scan(&inviteEventID); err != nil { + return nil, err + } + result = append(result, inviteEventID) + } + return result, nil } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs @@ -171,17 +147,3 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom( } return result, nil } - -func (s *inviteStatements) updateInviteSentInviteToOutput( - inviteEventNID types.EventNID, -) error { - _, err := s.updateInviteSentInviteToOutputStmt.Exec(inviteEventNID) - return err -} - -func (s *inviteStatements) updateInviteSentRetiredToOutput( - inviteEventNID types.EventNID, -) error { - _, err := s.updateInviteSentRetiredToOutputStmt.Exec(inviteEventNID) - return err -} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go index eada1a925..eb1ae3383 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/membership_table.go @@ -95,7 +95,7 @@ func (s *membershipStatements) selectMembershipForUpdate( return } -func (s *membershipStatements) updateMembershipSQL( +func (s *membershipStatements) updateMembership( txn *sql.Tx, roomNID types.RoomNID, targetNID types.EventStateKeyNID, senderNID types.EventStateKeyNID, membership membershipState, ) error { diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go index 37d43e024..fca7965b4 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go @@ -31,6 +31,7 @@ type statements struct { stateBlockStatements previousEventStatements inviteStatements + membershipStatements } func (s *statements) prepare(db *sql.DB) error { @@ -47,6 +48,7 @@ func (s *statements) prepare(db *sql.DB) error { s.stateBlockStatements.prepare, s.previousEventStatements.prepare, s.inviteStatements.prepare, + s.membershipStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 22a33b008..01ab12875 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -16,7 +16,6 @@ package storage import ( "database/sql" - "fmt" // Import the postgres database driver. _ "github.com/lib/pq" @@ -262,12 +261,15 @@ func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRe return nil, err } } - return &roomRecentEventsUpdater{txn, d, stateAndRefs, lastEventIDSent, currentStateSnapshotNID}, nil + return &roomRecentEventsUpdater{ + transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil } type roomRecentEventsUpdater struct { - txn *sql.Tx + transaction d *Database + roomNID types.RoomNID latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID @@ -332,18 +334,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return u.d.statements.updateEventSentToOutput(u.txn, eventNID) } -// Commit implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) Commit() error { - return u.txn.Commit() -} - -// Rollback implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) Rollback() error { - return u.txn.Rollback() -} - func (u *roomRecentEventsUpdater) MembershipUpdater(targetNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - panic(fmt.Errorf("Not implemented")) + return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetNID) } // RoomNID implements query.RoomserverQueryAPIDB @@ -378,3 +370,122 @@ func (d *Database) StateEntriesForTuples( ) ([]types.StateEntryList, error) { return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples) } + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + txn *sql.Tx, roomNID types.RoomNID, targetNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(txn, roomNID, targetNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{txn}, d, roomNID, targetNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + // TODO: assign the state key inside the transaction. + senderNID, err := u.d.assignStateKeyNID(event.Sender()) + if err != nil { + return false, err + } + eventID := event.EventID() + inserted, err := u.d.statements.insertInviteEvent( + u.txn, eventID, u.roomNID, u.targetNID, senderNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.txn, u.roomNID, u.targetNID, senderNID, membershipStateInvite, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderID string) ([]string, error) { + // TODO: assign the state key inside the transaction. + senderNID, err := u.d.assignStateKeyNID(senderID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.statements.updateInviteRetired( + u.txn, u.roomNID, u.targetNID, + ) + if err != nil { + return nil, err + } + if u.membership != membershipStateJoin { + if err = u.d.statements.updateMembership( + u.txn, u.roomNID, u.targetNID, senderNID, membershipStateJoin, + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} + +func (u *membershipUpdater) SetToLeave(senderID string) ([]string, error) { + // TODO: assign the state key inside the transaction. + senderNID, err := u.d.assignStateKeyNID(senderID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.statements.updateInviteRetired( + u.txn, u.roomNID, u.targetNID, + ) + if err != nil { + return nil, err + } + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.txn, u.roomNID, u.targetNID, senderNID, membershipStateLeaveOrBan, + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} + +type transaction struct { + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + return t.txn.Rollback() +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index c01d18a09..ba12fe492 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -188,16 +188,12 @@ type MembershipUpdater interface { IsJoin() bool // Set the state to invite. // Returns whether this invite needs to be sent - SetToInviteFrom(senderID string, event gomatrixserverlib.Event) (needsSending bool, err error) + SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error) // Set the state to join. - SetToJoinFrom(senderID string) (inviteIDs []string, err error) + SetToJoin(senderID string) (inviteIDs []string, err error) // Set the state to leave. // Returns a list of invite event IDs that this state change retired. - SetToLeaveFrom(senderID string) (inviteIDs []string, err error) - // Mark the invite as sent. - MarkInviteAsSent(inviteID string) error - // Mark the invite retirement as sent. - MarkInviteRetirementAsSent(inviteIDs []string) error + SetToLeave(senderID string) (inviteIDs []string, err error) // Implements Transaction so it can be committed or rolledback. Transaction }