From 9bdbb79ccd04d091b0a8077f9a21e8ba7bee4614 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 27 May 2020 09:54:30 +0100 Subject: [PATCH] Convert invites table --- roomserver/storage/postgres/invite_table.go | 18 ++++++++++-------- roomserver/storage/postgres/sql.go | 1 - roomserver/storage/postgres/storage.go | 21 +++++++++------------ roomserver/storage/shared/storage.go | 10 ++++++++++ roomserver/storage/sqlite3/invite_table.go | 16 +++++++++------- roomserver/storage/sqlite3/sql.go | 1 - roomserver/storage/sqlite3/storage.go | 21 +++++++++------------ roomserver/storage/tables/interface.go | 7 +++++++ 8 files changed, 54 insertions(+), 41 deletions(-) diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index f0fb919e6..4f1a6c63b 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -79,20 +80,21 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, }.prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, @@ -111,7 +113,7 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { @@ -133,8 +135,8 @@ func (s *inviteStatements) updateInviteRetired( return eventIDs, rows.Err() } -// selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index eb626dd88..1a84508ae 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -38,7 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.inviteStatements.prepare, s.membershipStatements.prepare, } { if err = prepare(db); err != nil { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 03cfb7f0e..992fc1dfb 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -41,6 +41,7 @@ type Database struct { rooms tables.Rooms transactions tables.Transactions prevEvents tables.PreviousEvents + invites tables.Invites db *sql.DB } @@ -95,6 +96,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.invites, err = NewPostgresInvitesTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -107,6 +112,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, StateSnapshotTable: stateSnapshot, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, + InvitesTable: d.invites, } return &d, nil } @@ -254,15 +260,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, @@ -349,7 +346,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er if err != nil { return false, err } - inserted, err := u.d.statements.insertInviteEvent( + inserted, err := u.d.invites.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { @@ -376,7 +373,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -408,7 +405,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s if err != nil { return nil, err } - inviteEventIDs, err := u.d.statements.updateInviteRetired( + inviteEventIDs, err := u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 29c6f73eb..311afbeb8 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -24,6 +24,7 @@ type Database struct { StateBlockTable tables.StateBlock RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites } // EventTypeNIDs implements state.RoomStateDatabase @@ -247,6 +248,15 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) } +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index a42d18a73..36da7cfff 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -66,13 +67,14 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func (s *inviteStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteSchema) +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} + _, err := db.Exec(inviteSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, @@ -80,7 +82,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *inviteStatements) insertInviteEvent( +func (s *inviteStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, @@ -101,7 +103,7 @@ func (s *inviteStatements) insertInviteEvent( return count != 0, nil } -func (s *inviteStatements) updateInviteRetired( +func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { @@ -127,7 +129,7 @@ func (s *inviteStatements) updateInviteRetired( } // selectInviteActiveForUserInRoom returns a list of sender state key NIDs -func (s *inviteStatements) selectInviteActiveForUserInRoom( +func (s *inviteStatements) SelectInviteActiveForUserInRoom( ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, error) { diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index df994d508..e07fc6465 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -38,7 +38,6 @@ func (s *statements) prepare(db *sql.DB) error { var err error for _, prepare := range []func(db *sql.DB) error{ - s.inviteStatements.prepare, s.membershipStatements.prepare, } { if err = prepare(db); err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 16697f1b9..0b5a7469e 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -42,6 +42,7 @@ type Database struct { rooms tables.Rooms transactions tables.Transactions prevEvents tables.PreviousEvents + invites tables.Invites db *sql.DB } @@ -115,6 +116,10 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + d.invites, err = NewSqliteInvitesTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventsTable: d.events, @@ -127,6 +132,7 @@ func Open(dataSourceName string) (*Database, error) { StateSnapshotTable: stateSnapshot, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, + InvitesTable: d.invites, } return &d, nil } @@ -305,15 +311,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return } -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, @@ -415,7 +412,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted if err != nil { return err } - inserted, err = u.d.statements.insertInviteEvent( + inserted, err = u.d.invites.InsertInviteEvent( u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { @@ -443,7 +440,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -478,7 +475,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv if err != nil { return err } - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, txn, u.roomNID, u.targetUserNID, ) if err != nil { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 56ef5dd39..c3fdb212b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -96,3 +96,10 @@ type PreviousEvents interface { // Returns sql.ErrNoRows if the event reference doesn't exist. SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error } + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) + UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) + // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs + SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) +}