From eebef093dc51a648b6421350d02a4159ab49f11f Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 27 May 2020 10:16:33 +0100 Subject: [PATCH] Convert membership table --- .../storage/postgres/membership_table.go | 38 +++++------ roomserver/storage/postgres/sql.go | 49 --------------- roomserver/storage/postgres/storage.go | 50 ++++++++------- roomserver/storage/shared/storage.go | 1 + .../storage/sqlite3/membership_table.go | 55 +++++++--------- roomserver/storage/sqlite3/sql.go | 49 --------------- roomserver/storage/sqlite3/storage.go | 63 +++++++++---------- roomserver/storage/tables/interface.go | 17 +++++ 8 files changed, 115 insertions(+), 207 deletions(-) delete mode 100644 roomserver/storage/postgres/sql.go delete mode 100644 roomserver/storage/sqlite3/sql.go diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index f290a05f9..9f0d97ccf 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -20,17 +20,10 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` -- The membership table is used to coordinate updates between the invite table -- and the room state tables. @@ -115,13 +108,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -133,7 +127,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -143,27 +137,27 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership) return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( +func (s *membershipStatements) SelectMembershipsFromRoom( ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -188,9 +182,9 @@ func (s *membershipStatements) selectMembershipsFromRoom( return eventNIDs, rows.Err() } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( ctx context.Context, - roomNID types.RoomNID, membership membershipState, localOnly bool, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows var stmt *sql.Stmt @@ -215,10 +209,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return eventNIDs, rows.Err() } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { _, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext( diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go deleted file mode 100644 index 1a84508ae..000000000 --- a/roomserver/storage/postgres/sql.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.membershipStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 992fc1dfb..521c841dd 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -33,7 +33,6 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database - statements statements events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys @@ -42,6 +41,7 @@ type Database struct { transactions tables.Transactions prevEvents tables.PreviousEvents invites tables.Invites + membership tables.Membership db *sql.DB } @@ -53,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db) if err != nil { return nil, err @@ -100,6 +97,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.membership, err = NewPostgresMembershipTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -113,6 +114,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, InvitesTable: d.invites, + MembershipTable: d.membership, } return &d, nil } @@ -300,7 +302,7 @@ type membershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership membershipState + membership tables.MembershipState } func (d *Database) membershipUpdaterTxn( @@ -311,11 +313,11 @@ func (d *Database) membershipUpdaterTxn( targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -327,17 +329,17 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite + return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin + return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan + return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater @@ -352,9 +354,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er if err != nil { return false, err } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + if u.membership != tables.MembershipStateInvite { + if err = u.d.membership.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { return false, err } @@ -387,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd return nil, err } - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], + tables.MembershipStateJoin, nIDs[eventID], ); err != nil { return nil, err } @@ -418,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s return nil, err } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], + tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return nil, err } @@ -439,7 +441,7 @@ func (d *Database) GetMembership( } senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( + d.membership.SelectMembershipFromRoomAndTarget( ctx, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { @@ -449,7 +451,7 @@ func (d *Database) GetMembership( return } - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil } // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB @@ -457,12 +459,12 @@ func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, localOnly, + return d.membership.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, ) } - return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) + return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly) } type transaction struct { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 311afbeb8..f6068cda3 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -25,6 +25,7 @@ type Database struct { RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents InvitesTable tables.Invites + MembershipTable tables.Membership } // EventTypeNIDs implements state.RoomStateDatabase diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 34108af43..8d7693586 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -20,17 +20,10 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) -type membershipState int64 - -const ( - membershipStateLeaveOrBan membershipState = 1 - membershipStateInvite membershipState = 2 - membershipStateJoin membershipState = 3 -) - const membershipSchema = ` CREATE TABLE IF NOT EXISTS roomserver_membership ( room_nid INTEGER NOT NULL, @@ -91,13 +84,14 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func (s *membershipStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(membershipSchema) +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} + _, err := db.Exec(membershipSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertMembershipStmt, insertMembershipSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, @@ -109,7 +103,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *membershipStatements) insertMembership( +func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, @@ -119,10 +113,10 @@ func (s *membershipStatements) insertMembership( return err } -func (s *membershipStatements) selectMembershipForUpdate( +func (s *membershipStatements) SelectMembershipForUpdate( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (membership membershipState, err error) { +) (membership tables.MembershipState, err error) { stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt) err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, @@ -130,26 +124,25 @@ func (s *membershipStatements) selectMembershipForUpdate( return } -func (s *membershipStatements) selectMembershipFromRoomAndTarget( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipFromRoomAndTarget( + ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership membershipState, err error) { - selectStmt := internal.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) - err = selectStmt.QueryRowContext( +) (eventNID types.EventNID, membership tables.MembershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID) return } -func (s *membershipStatements) selectMembershipsFromRoom( - ctx context.Context, txn *sql.Tx, +func (s *membershipStatements) SelectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt if localOnly { - selectStmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + selectStmt = s.selectLocalMembershipsFromRoomStmt } else { - selectStmt = internal.TxStmt(txn, s.selectMembershipsFromRoomStmt) + selectStmt = s.selectMembershipsFromRoomStmt } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { @@ -167,15 +160,15 @@ func (s *membershipStatements) selectMembershipsFromRoom( return } -func (s *membershipStatements) selectMembershipsFromRoomAndMembership( - ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, localOnly bool, +func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { - stmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt } else { - stmt = internal.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + stmt = s.selectMembershipsFromRoomAndMembershipStmt } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { @@ -193,10 +186,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership( return } -func (s *membershipStatements) updateMembership( +func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership membershipState, + senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { stmt := internal.TxStmt(txn, s.updateMembershipStmt) diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go deleted file mode 100644 index e07fc6465..000000000 --- a/roomserver/storage/sqlite3/sql.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "database/sql" -) - -type statements struct { - eventTypeStatements - eventStateKeyStatements - roomStatements - eventStatements - eventJSONStatements - stateSnapshotStatements - stateBlockStatements - previousEventStatements - roomAliasesStatements - inviteStatements - membershipStatements - transactionStatements -} - -func (s *statements) prepare(db *sql.DB) error { - var err error - - for _, prepare := range []func(db *sql.DB) error{ - s.membershipStatements.prepare, - } { - if err = prepare(db); err != nil { - return err - } - } - - return nil -} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 0b5a7469e..5803a6d87 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -34,7 +34,6 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database - statements statements events tables.Events eventJSON tables.EventJSON eventTypes tables.EventTypes @@ -43,6 +42,7 @@ type Database struct { transactions tables.Transactions prevEvents tables.PreviousEvents invites tables.Invites + membership tables.Membership db *sql.DB } @@ -73,9 +73,7 @@ func Open(dataSourceName string) (*Database, error) { // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. d.db.SetMaxOpenConns(20) - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) if err != nil { return nil, err @@ -120,6 +118,10 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.membership, err = NewSqliteMembershipTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventsTable: d.events, @@ -133,6 +135,7 @@ func Open(dataSourceName string) (*Database, error) { PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, InvitesTable: d.invites, + MembershipTable: d.membership, } return &d, nil } @@ -364,7 +367,7 @@ type membershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership membershipState + membership tables.MembershipState } func (d *Database) membershipUpdaterTxn( @@ -375,11 +378,11 @@ func (d *Database) membershipUpdaterTxn( targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -392,17 +395,17 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite + return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin + return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan + return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater @@ -418,9 +421,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted if err != nil { return err } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + if u.membership != tables.MembershipStateInvite { + if err = u.d.membership.UpdateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { return err } @@ -454,10 +457,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd return err } - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.membership.UpdateMembership( u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], + tables.MembershipStateJoin, nIDs[eventID], ); err != nil { return err } @@ -488,10 +491,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv return err } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.membership.UpdateMembership( u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], + tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return err } @@ -512,8 +515,8 @@ func (d *Database) GetMembership( } membershipEventNID, _, err = - d.statements.selectMembershipFromRoomAndTarget( - ctx, txn, roomNID, requestSenderUserNID, + d.membership.SelectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room @@ -533,18 +536,14 @@ func (d *Database) GetMembership( func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { - if joinOnly { - eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, localOnly, - ) - return nil - } + if joinOnly { + eventNIDs, err = d.membership.SelectMembershipsFromRoomAndMembership( + ctx, roomNID, tables.MembershipStateJoin, localOnly, + ) + return + } - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) - return nil - }) - return + return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly) } type transaction struct { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index c3fdb212b..11cff8a8b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -103,3 +103,20 @@ type Invites interface { // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) } + +type MembershipState int64 + +const ( + MembershipStateLeaveOrBan MembershipState = 1 + MembershipStateInvite MembershipState = 2 + MembershipStateJoin MembershipState = 3 +) + +type Membership interface { + InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error + SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error +}