diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go index ad5312db6..f00bd24da 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/authtypes/membership.go @@ -17,7 +17,8 @@ package authtypes // Membership represents the relationship between a user and a room they're a // member of type Membership struct { - Localpart string - RoomID string - EventID string + Localpart string + RoomID string + EventID string + StillInRoom bool } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go index 3c4e8d3af..9736f35e4 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/membership_table.go @@ -52,7 +52,7 @@ const insertMembershipSQL = ` ` const selectMembershipSQL = "" + - "SELECT * from memberships WHERE localpart = $1 AND room_id = $2" + "SELECT event_id, still_in_room from memberships WHERE localpart = $1 AND room_id = $2" const selectMembershipsByLocalpartSQL = "" + "SELECT room_id, event_id FROM memberships WHERE localpart = $1 AND still_in_room = true" @@ -69,6 +69,7 @@ const updateMembershipByEventIDSQL = "" + type membershipStatements struct { deleteMembershipsByEventIDsStmt *sql.Stmt insertMembershipStmt *sql.Stmt + selectMembershipStmt *sql.Stmt selectMembershipByEventIDStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByRoomIDStmt *sql.Stmt @@ -86,6 +87,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { return } + if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { + return + } if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } @@ -108,6 +112,18 @@ func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, tx return } +func (s *membershipStatements) selectMembership(localpart string, roomID string) (*authtypes.Membership, error) { + m := authtypes.Membership{ + Localpart: localpart, + RoomID: roomID, + } + err := s.selectMembershipStmt.QueryRow(localpart, roomID).Scan(&m.EventID, &m.StillInRoom) + if err == sql.ErrNoRows { + return nil, nil + } + return &m, err +} + func (s *membershipStatements) selectMembershipsByLocalpart(localpart string) (memberships []authtypes.Membership, err error) { rows, err := s.selectMembershipsByLocalpartStmt.Query(localpart) if err != nil { diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index 3b72acc3d..d1fd12701 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -157,6 +157,13 @@ func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsT }) } +// GetMembership returns the membership for the given localpart and room ID +// If no membership match the given localpart, returns nil +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembership(localpart string, roomID string) (*authtypes.Membership, error) { + return d.memberships.selectMembership(localpart, roomID) +} + // GetMembershipsByLocalpart returns an array containing the memberships for all // the rooms a user matching a given localpart is a member of // If no membership match the given localpart, returns an empty array