Convert membership table

This commit is contained in:
Kegan Dougal 2020-05-27 10:16:33 +01:00
parent 9bdbb79ccd
commit eebef093dc
8 changed files with 115 additions and 207 deletions

View file

@ -20,17 +20,10 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
type membershipState int64
const (
membershipStateLeaveOrBan membershipState = 1
membershipStateInvite membershipState = 2
membershipStateJoin membershipState = 3
)
const membershipSchema = ` const membershipSchema = `
-- The membership table is used to coordinate updates between the invite table -- The membership table is used to coordinate updates between the invite table
-- and the room state tables. -- and the room state tables.
@ -115,13 +108,14 @@ type membershipStatements struct {
updateMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
_, err = db.Exec(membershipSchema) s := &membershipStatements{}
_, err := db.Exec(membershipSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertMembershipStmt, insertMembershipSQL}, {&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
@ -133,7 +127,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *membershipStatements) insertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
@ -143,27 +137,27 @@ func (s *membershipStatements) insertMembership(
return err return err
} }
func (s *membershipStatements) selectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) { ) (membership tables.MembershipState, err error) {
err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership) ).Scan(&membership)
return return
} }
func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID) ).Scan(&membership, &eventNID)
return return
} }
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool, ctx context.Context, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -188,9 +182,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
return eventNIDs, rows.Err() return eventNIDs, rows.Err()
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership membershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows var rows *sql.Rows
var stmt *sql.Stmt var stmt *sql.Stmt
@ -215,10 +209,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
return eventNIDs, rows.Err() return eventNIDs, rows.Err()
} }
func (s *membershipStatements) updateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
_, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext( _, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext(

View file

@ -1,49 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
)
type statements struct {
eventTypeStatements
eventStateKeyStatements
roomStatements
eventStatements
eventJSONStatements
stateSnapshotStatements
stateBlockStatements
previousEventStatements
roomAliasesStatements
inviteStatements
membershipStatements
transactionStatements
}
func (s *statements) prepare(db *sql.DB) error {
var err error
for _, prepare := range []func(db *sql.DB) error{
s.membershipStatements.prepare,
} {
if err = prepare(db); err != nil {
return err
}
}
return nil
}

View file

@ -33,7 +33,6 @@ import (
// A Database is used to store room events and stream offsets. // A Database is used to store room events and stream offsets.
type Database struct { type Database struct {
shared.Database shared.Database
statements statements
events tables.Events events tables.Events
eventTypes tables.EventTypes eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys eventStateKeys tables.EventStateKeys
@ -42,6 +41,7 @@ type Database struct {
transactions tables.Transactions transactions tables.Transactions
prevEvents tables.PreviousEvents prevEvents tables.PreviousEvents
invites tables.Invites invites tables.Invites
membership tables.Membership
db *sql.DB db *sql.DB
} }
@ -53,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil {
return nil, err return nil, err
} }
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db) d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -100,6 +97,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.membership, err = NewPostgresMembershipTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
@ -113,6 +114,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
PrevEventsTable: d.prevEvents, PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites, InvitesTable: d.invites,
MembershipTable: d.membership,
} }
return &d, nil return &d, nil
} }
@ -300,7 +302,7 @@ type membershipUpdater struct {
d *Database d *Database
roomNID types.RoomNID roomNID types.RoomNID
targetUserNID types.EventStateKeyNID targetUserNID types.EventStateKeyNID
membership membershipState membership tables.MembershipState
} }
func (d *Database) membershipUpdaterTxn( func (d *Database) membershipUpdaterTxn(
@ -311,11 +313,11 @@ func (d *Database) membershipUpdaterTxn(
targetLocal bool, targetLocal bool,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err return nil, err
} }
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -327,17 +329,17 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater // IsInvite implements types.MembershipUpdater
func (u *membershipUpdater) IsInvite() bool { func (u *membershipUpdater) IsInvite() bool {
return u.membership == membershipStateInvite return u.membership == tables.MembershipStateInvite
} }
// IsJoin implements types.MembershipUpdater // IsJoin implements types.MembershipUpdater
func (u *membershipUpdater) IsJoin() bool { func (u *membershipUpdater) IsJoin() bool {
return u.membership == membershipStateJoin return u.membership == tables.MembershipStateJoin
} }
// IsLeave implements types.MembershipUpdater // IsLeave implements types.MembershipUpdater
func (u *membershipUpdater) IsLeave() bool { func (u *membershipUpdater) IsLeave() bool {
return u.membership == membershipStateLeaveOrBan return u.membership == tables.MembershipStateLeaveOrBan
} }
// SetToInvite implements types.MembershipUpdater // SetToInvite implements types.MembershipUpdater
@ -352,9 +354,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
if err != nil { if err != nil {
return false, err return false, err
} }
if u.membership != membershipStateInvite { if u.membership != tables.MembershipStateInvite {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil { ); err != nil {
return false, err return false, err
} }
@ -387,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
return nil, err return nil, err
} }
if u.membership != membershipStateJoin || isUpdate { if u.membership != tables.MembershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID], tables.MembershipStateJoin, nIDs[eventID],
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -418,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
return nil, err return nil, err
} }
if u.membership != membershipStateLeaveOrBan { if u.membership != tables.MembershipStateLeaveOrBan {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID], tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -439,7 +441,7 @@ func (d *Database) GetMembership(
} }
senderMembershipEventNID, senderMembership, err := senderMembershipEventNID, senderMembership, err :=
d.statements.selectMembershipFromRoomAndTarget( d.membership.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID, ctx, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -449,7 +451,7 @@ func (d *Database) GetMembership(
return return
} }
return senderMembershipEventNID, senderMembership == membershipStateJoin, nil return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil
} }
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
@ -457,12 +459,12 @@ func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership( return d.membership.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin, localOnly, ctx, roomNID, tables.MembershipStateJoin, localOnly,
) )
} }
return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
} }
type transaction struct { type transaction struct {

View file

@ -25,6 +25,7 @@ type Database struct {
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites InvitesTable tables.Invites
MembershipTable tables.Membership
} }
// EventTypeNIDs implements state.RoomStateDatabase // EventTypeNIDs implements state.RoomStateDatabase

View file

@ -20,17 +20,10 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
type membershipState int64
const (
membershipStateLeaveOrBan membershipState = 1
membershipStateInvite membershipState = 2
membershipStateJoin membershipState = 3
)
const membershipSchema = ` const membershipSchema = `
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
@ -91,13 +84,14 @@ type membershipStatements struct {
updateMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
_, err = db.Exec(membershipSchema) s := &membershipStatements{}
_, err := db.Exec(membershipSchema)
if err != nil { if err != nil {
return return nil, err
} }
return statementList{ return s, statementList{
{&s.insertMembershipStmt, insertMembershipSQL}, {&s.insertMembershipStmt, insertMembershipSQL},
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
@ -109,7 +103,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *membershipStatements) insertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool, localTarget bool,
@ -119,10 +113,10 @@ func (s *membershipStatements) insertMembership(
return err return err
} }
func (s *membershipStatements) selectMembershipForUpdate( func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) { ) (membership tables.MembershipState, err error) {
stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt) stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt)
err = stmt.QueryRowContext( err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
@ -130,26 +124,25 @@ func (s *membershipStatements) selectMembershipForUpdate(
return return
} }
func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, txn *sql.Tx, ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, err error) {
selectStmt := internal.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
err = selectStmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID) ).Scan(&membership, &eventNID)
return return
} }
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx, ctx context.Context,
roomNID types.RoomNID, localOnly bool, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt var selectStmt *sql.Stmt
if localOnly { if localOnly {
selectStmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) selectStmt = s.selectLocalMembershipsFromRoomStmt
} else { } else {
selectStmt = internal.TxStmt(txn, s.selectMembershipsFromRoomStmt) selectStmt = s.selectMembershipsFromRoomStmt
} }
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
@ -167,15 +160,15 @@ func (s *membershipStatements) selectMembershipsFromRoom(
return return
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context,
roomNID types.RoomNID, membership membershipState, localOnly bool, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt var stmt *sql.Stmt
if localOnly { if localOnly {
stmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
} else { } else {
stmt = internal.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) stmt = s.selectMembershipsFromRoomAndMembershipStmt
} }
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
@ -193,10 +186,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
return return
} }
func (s *membershipStatements) updateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
stmt := internal.TxStmt(txn, s.updateMembershipStmt) stmt := internal.TxStmt(txn, s.updateMembershipStmt)

View file

@ -1,49 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
)
type statements struct {
eventTypeStatements
eventStateKeyStatements
roomStatements
eventStatements
eventJSONStatements
stateSnapshotStatements
stateBlockStatements
previousEventStatements
roomAliasesStatements
inviteStatements
membershipStatements
transactionStatements
}
func (s *statements) prepare(db *sql.DB) error {
var err error
for _, prepare := range []func(db *sql.DB) error{
s.membershipStatements.prepare,
} {
if err = prepare(db); err != nil {
return err
}
}
return nil
}

View file

@ -34,7 +34,6 @@ import (
// A Database is used to store room events and stream offsets. // A Database is used to store room events and stream offsets.
type Database struct { type Database struct {
shared.Database shared.Database
statements statements
events tables.Events events tables.Events
eventJSON tables.EventJSON eventJSON tables.EventJSON
eventTypes tables.EventTypes eventTypes tables.EventTypes
@ -43,6 +42,7 @@ type Database struct {
transactions tables.Transactions transactions tables.Transactions
prevEvents tables.PreviousEvents prevEvents tables.PreviousEvents
invites tables.Invites invites tables.Invites
membership tables.Membership
db *sql.DB db *sql.DB
} }
@ -73,9 +73,7 @@ func Open(dataSourceName string) (*Database, error) {
// acquire the global mutex and never unlock it because it is waiting for a connection // acquire the global mutex and never unlock it because it is waiting for a connection
// which it will never obtain. // which it will never obtain.
d.db.SetMaxOpenConns(20) d.db.SetMaxOpenConns(20)
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -120,6 +118,10 @@ func Open(dataSourceName string) (*Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.membership, err = NewSqliteMembershipTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
EventsTable: d.events, EventsTable: d.events,
@ -133,6 +135,7 @@ func Open(dataSourceName string) (*Database, error) {
PrevEventsTable: d.prevEvents, PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites, InvitesTable: d.invites,
MembershipTable: d.membership,
} }
return &d, nil return &d, nil
} }
@ -364,7 +367,7 @@ type membershipUpdater struct {
d *Database d *Database
roomNID types.RoomNID roomNID types.RoomNID
targetUserNID types.EventStateKeyNID targetUserNID types.EventStateKeyNID
membership membershipState membership tables.MembershipState
} }
func (d *Database) membershipUpdaterTxn( func (d *Database) membershipUpdaterTxn(
@ -375,11 +378,11 @@ func (d *Database) membershipUpdaterTxn(
targetLocal bool, targetLocal bool,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err return nil, err
} }
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -392,17 +395,17 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater // IsInvite implements types.MembershipUpdater
func (u *membershipUpdater) IsInvite() bool { func (u *membershipUpdater) IsInvite() bool {
return u.membership == membershipStateInvite return u.membership == tables.MembershipStateInvite
} }
// IsJoin implements types.MembershipUpdater // IsJoin implements types.MembershipUpdater
func (u *membershipUpdater) IsJoin() bool { func (u *membershipUpdater) IsJoin() bool {
return u.membership == membershipStateJoin return u.membership == tables.MembershipStateJoin
} }
// IsLeave implements types.MembershipUpdater // IsLeave implements types.MembershipUpdater
func (u *membershipUpdater) IsLeave() bool { func (u *membershipUpdater) IsLeave() bool {
return u.membership == membershipStateLeaveOrBan return u.membership == tables.MembershipStateLeaveOrBan
} }
// SetToInvite implements types.MembershipUpdater // SetToInvite implements types.MembershipUpdater
@ -418,9 +421,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted
if err != nil { if err != nil {
return err return err
} }
if u.membership != membershipStateInvite { if u.membership != tables.MembershipStateInvite {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil { ); err != nil {
return err return err
} }
@ -454,10 +457,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
return err return err
} }
if u.membership != membershipStateJoin || isUpdate { if u.membership != tables.MembershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID], tables.MembershipStateJoin, nIDs[eventID],
); err != nil { ); err != nil {
return err return err
} }
@ -488,10 +491,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv
return err return err
} }
if u.membership != membershipStateLeaveOrBan { if u.membership != tables.MembershipStateLeaveOrBan {
if err = u.d.statements.updateMembership( if err = u.d.membership.UpdateMembership(
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID], tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil { ); err != nil {
return err return err
} }
@ -512,8 +515,8 @@ func (d *Database) GetMembership(
} }
membershipEventNID, _, err = membershipEventNID, _, err =
d.statements.selectMembershipFromRoomAndTarget( d.membership.SelectMembershipFromRoomAndTarget(
ctx, txn, roomNID, requestSenderUserNID, ctx, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
@ -533,18 +536,14 @@ func (d *Database) GetMembership(
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { if joinOnly {
if joinOnly { eventNIDs, err = d.membership.SelectMembershipsFromRoomAndMembership(
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( ctx, roomNID, tables.MembershipStateJoin, localOnly,
ctx, txn, roomNID, membershipStateJoin, localOnly, )
) return
return nil }
}
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
return nil
})
return
} }
type transaction struct { type transaction struct {

View file

@ -103,3 +103,20 @@ type Invites interface {
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error) SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error)
} }
type MembershipState int64
const (
MembershipStateLeaveOrBan MembershipState = 1
MembershipStateInvite MembershipState = 2
MembershipStateJoin MembershipState = 3
)
type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
}