Convert invites table

This commit is contained in:
Kegan Dougal 2020-05-27 09:54:30 +01:00
parent 267a4d1823
commit 9bdbb79ccd
8 changed files with 54 additions and 41 deletions

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -79,20 +80,21 @@ type inviteStatements struct {
updateInviteRetiredStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt
} }
func (s *inviteStatements) prepare(db *sql.DB) (err error) { func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) {
_, err = db.Exec(inviteSchema) s := &inviteStatements{}
_, err := db.Exec(inviteSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
}.prepare(db) }.prepare(db)
} }
func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
@ -111,7 +113,7 @@ func (s *inviteStatements) insertInviteEvent(
return count != 0, nil return count != 0, nil
} }
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
@ -133,8 +135,8 @@ func (s *inviteStatements) updateInviteRetired(
return eventIDs, rows.Err() return eventIDs, rows.Err()
} }
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) selectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) { ) ([]types.EventStateKeyNID, error) {

View file

@ -38,7 +38,6 @@ func (s *statements) prepare(db *sql.DB) error {
var err error var err error
for _, prepare := range []func(db *sql.DB) error{ for _, prepare := range []func(db *sql.DB) error{
s.inviteStatements.prepare,
s.membershipStatements.prepare, s.membershipStatements.prepare,
} { } {
if err = prepare(db); err != nil { if err = prepare(db); err != nil {

View file

@ -41,6 +41,7 @@ type Database struct {
rooms tables.Rooms rooms tables.Rooms
transactions tables.Transactions transactions tables.Transactions
prevEvents tables.PreviousEvents prevEvents tables.PreviousEvents
invites tables.Invites
db *sql.DB db *sql.DB
} }
@ -95,6 +96,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.invites, err = NewPostgresInvitesTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
@ -107,6 +112,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents, PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites,
} }
return &d, nil return &d, nil
} }
@ -254,15 +260,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
} }
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
@ -349,7 +346,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
if err != nil { if err != nil {
return false, err return false, err
} }
inserted, err := u.d.statements.insertInviteEvent( inserted, err := u.d.invites.InsertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
) )
if err != nil { if err != nil {
@ -376,7 +373,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// If this is a join event update, there is no invite to update // If this is a join event update, there is no invite to update
if !isUpdate { if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired( inviteEventIDs, err = u.d.invites.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
@ -408,7 +405,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviteEventIDs, err := u.d.statements.updateInviteRetired( inviteEventIDs, err := u.d.invites.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {

View file

@ -24,6 +24,7 @@ type Database struct {
StateBlockTable tables.StateBlock StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
} }
// EventTypeNIDs implements state.RoomStateDatabase // EventTypeNIDs implements state.RoomStateDatabase
@ -247,6 +248,15 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
} }
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
// Events implements input.EventDatabase // Events implements input.EventDatabase
func (d *Database) Events( func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -66,13 +67,14 @@ type inviteStatements struct {
selectInvitesAboutToRetireStmt *sql.Stmt selectInvitesAboutToRetireStmt *sql.Stmt
} }
func (s *inviteStatements) prepare(db *sql.DB) (err error) { func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
_, err = db.Exec(inviteSchema) s := &inviteStatements{}
_, err := db.Exec(inviteSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertInviteEventStmt, insertInviteEventSQL}, {&s.insertInviteEventStmt, insertInviteEventSQL},
{&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
{&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
@ -80,7 +82,7 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) InsertInviteEvent(
ctx context.Context, ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
@ -101,7 +103,7 @@ func (s *inviteStatements) insertInviteEvent(
return count != 0, nil return count != 0, nil
} }
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
@ -127,7 +129,7 @@ func (s *inviteStatements) updateInviteRetired(
} }
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) selectInviteActiveForUserInRoom( func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context, ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) { ) ([]types.EventStateKeyNID, error) {

View file

@ -38,7 +38,6 @@ func (s *statements) prepare(db *sql.DB) error {
var err error var err error
for _, prepare := range []func(db *sql.DB) error{ for _, prepare := range []func(db *sql.DB) error{
s.inviteStatements.prepare,
s.membershipStatements.prepare, s.membershipStatements.prepare,
} { } {
if err = prepare(db); err != nil { if err = prepare(db); err != nil {

View file

@ -42,6 +42,7 @@ type Database struct {
rooms tables.Rooms rooms tables.Rooms
transactions tables.Transactions transactions tables.Transactions
prevEvents tables.PreviousEvents prevEvents tables.PreviousEvents
invites tables.Invites
db *sql.DB db *sql.DB
} }
@ -115,6 +116,10 @@ func Open(dataSourceName string) (*Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.invites, err = NewSqliteInvitesTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
EventsTable: d.events, EventsTable: d.events,
@ -127,6 +132,7 @@ func Open(dataSourceName string) (*Database, error) {
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents, PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites,
} }
return &d, nil return &d, nil
} }
@ -305,15 +311,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta
return return
} }
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
@ -415,7 +412,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted
if err != nil { if err != nil {
return err return err
} }
inserted, err = u.d.statements.insertInviteEvent( inserted, err = u.d.invites.InsertInviteEvent(
u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
) )
if err != nil { if err != nil {
@ -443,7 +440,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// If this is a join event update, there is no invite to update // If this is a join event update, there is no invite to update
if !isUpdate { if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired( inviteEventIDs, err = u.d.invites.UpdateInviteRetired(
u.ctx, txn, u.roomNID, u.targetUserNID, u.ctx, txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
@ -478,7 +475,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv
if err != nil { if err != nil {
return err return err
} }
inviteEventIDs, err = u.d.statements.updateInviteRetired( inviteEventIDs, err = u.d.invites.UpdateInviteRetired(
u.ctx, txn, u.roomNID, u.targetUserNID, u.ctx, txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {

View file

@ -96,3 +96,10 @@ type PreviousEvents interface {
// Returns sql.ErrNoRows if the event reference doesn't exist. // Returns sql.ErrNoRows if the event reference doesn't exist.
SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error
} }
type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error)
}