mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-20 05:13:11 -06:00
Convert membership table
This commit is contained in:
parent
9bdbb79ccd
commit
eebef093dc
|
|
@ -20,17 +20,10 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type membershipState int64
|
|
||||||
|
|
||||||
const (
|
|
||||||
membershipStateLeaveOrBan membershipState = 1
|
|
||||||
membershipStateInvite membershipState = 2
|
|
||||||
membershipStateJoin membershipState = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
const membershipSchema = `
|
const membershipSchema = `
|
||||||
-- The membership table is used to coordinate updates between the invite table
|
-- The membership table is used to coordinate updates between the invite table
|
||||||
-- and the room state tables.
|
-- and the room state tables.
|
||||||
|
|
@ -115,13 +108,14 @@ type membershipStatements struct {
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
_, err = db.Exec(membershipSchema)
|
s := &membershipStatements{}
|
||||||
|
_, err := db.Exec(membershipSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return statementList{
|
return s, statementList{
|
||||||
{&s.insertMembershipStmt, insertMembershipSQL},
|
{&s.insertMembershipStmt, insertMembershipSQL},
|
||||||
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
||||||
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
||||||
|
|
@ -133,7 +127,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) insertMembership(
|
func (s *membershipStatements) InsertMembership(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
localTarget bool,
|
localTarget bool,
|
||||||
|
|
@ -143,27 +137,27 @@ func (s *membershipStatements) insertMembership(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipForUpdate(
|
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (membership membershipState, err error) {
|
) (membership tables.MembershipState, err error) {
|
||||||
err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
err = internal.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership)
|
).Scan(&membership)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
|
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventNID types.EventNID, membership membershipState, err error) {
|
) (eventNID types.EventNID, membership tables.MembershipState, err error) {
|
||||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership, &eventNID)
|
).Scan(&membership, &eventNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoom(
|
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
ctx context.Context, roomNID types.RoomNID, localOnly bool,
|
ctx context.Context, roomNID types.RoomNID, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
|
|
@ -188,9 +182,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
return eventNIDs, rows.Err()
|
return eventNIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, membership membershipState, localOnly bool,
|
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
|
|
@ -215,10 +209,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
return eventNIDs, rows.Err()
|
return eventNIDs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) updateMembership(
|
func (s *membershipStatements) UpdateMembership(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
senderUserNID types.EventStateKeyNID, membership membershipState,
|
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
_, err := internal.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
||||||
|
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type statements struct {
|
|
||||||
eventTypeStatements
|
|
||||||
eventStateKeyStatements
|
|
||||||
roomStatements
|
|
||||||
eventStatements
|
|
||||||
eventJSONStatements
|
|
||||||
stateSnapshotStatements
|
|
||||||
stateBlockStatements
|
|
||||||
previousEventStatements
|
|
||||||
roomAliasesStatements
|
|
||||||
inviteStatements
|
|
||||||
membershipStatements
|
|
||||||
transactionStatements
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *statements) prepare(db *sql.DB) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
for _, prepare := range []func(db *sql.DB) error{
|
|
||||||
s.membershipStatements.prepare,
|
|
||||||
} {
|
|
||||||
if err = prepare(db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -33,7 +33,6 @@ import (
|
||||||
// A Database is used to store room events and stream offsets.
|
// A Database is used to store room events and stream offsets.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
statements statements
|
|
||||||
events tables.Events
|
events tables.Events
|
||||||
eventTypes tables.EventTypes
|
eventTypes tables.EventTypes
|
||||||
eventStateKeys tables.EventStateKeys
|
eventStateKeys tables.EventStateKeys
|
||||||
|
|
@ -42,6 +41,7 @@ type Database struct {
|
||||||
transactions tables.Transactions
|
transactions tables.Transactions
|
||||||
prevEvents tables.PreviousEvents
|
prevEvents tables.PreviousEvents
|
||||||
invites tables.Invites
|
invites tables.Invites
|
||||||
|
membership tables.Membership
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -53,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
|
||||||
if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil {
|
if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db)
|
d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -100,6 +97,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.membership, err = NewPostgresMembershipTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
EventTypesTable: d.eventTypes,
|
EventTypesTable: d.eventTypes,
|
||||||
|
|
@ -113,6 +114,7 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
|
||||||
PrevEventsTable: d.prevEvents,
|
PrevEventsTable: d.prevEvents,
|
||||||
RoomAliasesTable: roomAliases,
|
RoomAliasesTable: roomAliases,
|
||||||
InvitesTable: d.invites,
|
InvitesTable: d.invites,
|
||||||
|
MembershipTable: d.membership,
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
@ -300,7 +302,7 @@ type membershipUpdater struct {
|
||||||
d *Database
|
d *Database
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
targetUserNID types.EventStateKeyNID
|
targetUserNID types.EventStateKeyNID
|
||||||
membership membershipState
|
membership tables.MembershipState
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) membershipUpdaterTxn(
|
func (d *Database) membershipUpdaterTxn(
|
||||||
|
|
@ -311,11 +313,11 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
targetLocal bool,
|
targetLocal bool,
|
||||||
) (types.MembershipUpdater, error) {
|
) (types.MembershipUpdater, error) {
|
||||||
|
|
||||||
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -327,17 +329,17 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
|
|
||||||
// IsInvite implements types.MembershipUpdater
|
// IsInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsInvite() bool {
|
func (u *membershipUpdater) IsInvite() bool {
|
||||||
return u.membership == membershipStateInvite
|
return u.membership == tables.MembershipStateInvite
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsJoin implements types.MembershipUpdater
|
// IsJoin implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsJoin() bool {
|
func (u *membershipUpdater) IsJoin() bool {
|
||||||
return u.membership == membershipStateJoin
|
return u.membership == tables.MembershipStateJoin
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLeave implements types.MembershipUpdater
|
// IsLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsLeave() bool {
|
func (u *membershipUpdater) IsLeave() bool {
|
||||||
return u.membership == membershipStateLeaveOrBan
|
return u.membership == tables.MembershipStateLeaveOrBan
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToInvite implements types.MembershipUpdater
|
// SetToInvite implements types.MembershipUpdater
|
||||||
|
|
@ -352,9 +354,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if u.membership != membershipStateInvite {
|
if u.membership != tables.MembershipStateInvite {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
@ -387,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateJoin || isUpdate {
|
if u.membership != tables.MembershipStateJoin || isUpdate {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
membershipStateJoin, nIDs[eventID],
|
tables.MembershipStateJoin, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -418,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateLeaveOrBan {
|
if u.membership != tables.MembershipStateLeaveOrBan {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
membershipStateLeaveOrBan, nIDs[eventID],
|
tables.MembershipStateLeaveOrBan, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -439,7 +441,7 @@ func (d *Database) GetMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
senderMembershipEventNID, senderMembership, err :=
|
senderMembershipEventNID, senderMembership, err :=
|
||||||
d.statements.selectMembershipFromRoomAndTarget(
|
d.membership.SelectMembershipFromRoomAndTarget(
|
||||||
ctx, roomNID, requestSenderUserNID,
|
ctx, roomNID, requestSenderUserNID,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
|
@ -449,7 +451,7 @@ func (d *Database) GetMembership(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return senderMembershipEventNID, senderMembership == membershipStateJoin, nil
|
return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
||||||
|
|
@ -457,12 +459,12 @@ func (d *Database) GetMembershipEventNIDsForRoom(
|
||||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||||
) ([]types.EventNID, error) {
|
) ([]types.EventNID, error) {
|
||||||
if joinOnly {
|
if joinOnly {
|
||||||
return d.statements.selectMembershipsFromRoomAndMembership(
|
return d.membership.SelectMembershipsFromRoomAndMembership(
|
||||||
ctx, roomNID, membershipStateJoin, localOnly,
|
ctx, roomNID, tables.MembershipStateJoin, localOnly,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
|
return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
type transaction struct {
|
type transaction struct {
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ type Database struct {
|
||||||
RoomAliasesTable tables.RoomAliases
|
RoomAliasesTable tables.RoomAliases
|
||||||
PrevEventsTable tables.PreviousEvents
|
PrevEventsTable tables.PreviousEvents
|
||||||
InvitesTable tables.Invites
|
InvitesTable tables.Invites
|
||||||
|
MembershipTable tables.Membership
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventTypeNIDs implements state.RoomStateDatabase
|
// EventTypeNIDs implements state.RoomStateDatabase
|
||||||
|
|
|
||||||
|
|
@ -20,17 +20,10 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type membershipState int64
|
|
||||||
|
|
||||||
const (
|
|
||||||
membershipStateLeaveOrBan membershipState = 1
|
|
||||||
membershipStateInvite membershipState = 2
|
|
||||||
membershipStateJoin membershipState = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
const membershipSchema = `
|
const membershipSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||||
room_nid INTEGER NOT NULL,
|
room_nid INTEGER NOT NULL,
|
||||||
|
|
@ -91,13 +84,14 @@ type membershipStatements struct {
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
_, err = db.Exec(membershipSchema)
|
s := &membershipStatements{}
|
||||||
|
_, err := db.Exec(membershipSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return statementList{
|
return s, statementList{
|
||||||
{&s.insertMembershipStmt, insertMembershipSQL},
|
{&s.insertMembershipStmt, insertMembershipSQL},
|
||||||
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
|
||||||
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
|
||||||
|
|
@ -109,7 +103,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) insertMembership(
|
func (s *membershipStatements) InsertMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
localTarget bool,
|
localTarget bool,
|
||||||
|
|
@ -119,10 +113,10 @@ func (s *membershipStatements) insertMembership(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipForUpdate(
|
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (membership membershipState, err error) {
|
) (membership tables.MembershipState, err error) {
|
||||||
stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt)
|
stmt := internal.TxStmt(txn, s.selectMembershipForUpdateStmt)
|
||||||
err = stmt.QueryRowContext(
|
err = stmt.QueryRowContext(
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
|
|
@ -130,26 +124,25 @@ func (s *membershipStatements) selectMembershipForUpdate(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
|
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventNID types.EventNID, membership membershipState, err error) {
|
) (eventNID types.EventNID, membership tables.MembershipState, err error) {
|
||||||
selectStmt := internal.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||||
err = selectStmt.QueryRowContext(
|
|
||||||
ctx, roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership, &eventNID)
|
).Scan(&membership, &eventNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoom(
|
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, localOnly bool,
|
roomNID types.RoomNID, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var selectStmt *sql.Stmt
|
var selectStmt *sql.Stmt
|
||||||
if localOnly {
|
if localOnly {
|
||||||
selectStmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
|
selectStmt = s.selectLocalMembershipsFromRoomStmt
|
||||||
} else {
|
} else {
|
||||||
selectStmt = internal.TxStmt(txn, s.selectMembershipsFromRoomStmt)
|
selectStmt = s.selectMembershipsFromRoomStmt
|
||||||
}
|
}
|
||||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -167,15 +160,15 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, membership membershipState, localOnly bool,
|
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
var stmt *sql.Stmt
|
var stmt *sql.Stmt
|
||||||
if localOnly {
|
if localOnly {
|
||||||
stmt = internal.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
|
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
|
||||||
} else {
|
} else {
|
||||||
stmt = internal.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
|
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||||
}
|
}
|
||||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -193,10 +186,10 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) updateMembership(
|
func (s *membershipStatements) UpdateMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
senderUserNID types.EventStateKeyNID, membership membershipState,
|
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
stmt := internal.TxStmt(txn, s.updateMembershipStmt)
|
stmt := internal.TxStmt(txn, s.updateMembershipStmt)
|
||||||
|
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type statements struct {
|
|
||||||
eventTypeStatements
|
|
||||||
eventStateKeyStatements
|
|
||||||
roomStatements
|
|
||||||
eventStatements
|
|
||||||
eventJSONStatements
|
|
||||||
stateSnapshotStatements
|
|
||||||
stateBlockStatements
|
|
||||||
previousEventStatements
|
|
||||||
roomAliasesStatements
|
|
||||||
inviteStatements
|
|
||||||
membershipStatements
|
|
||||||
transactionStatements
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *statements) prepare(db *sql.DB) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
for _, prepare := range []func(db *sql.DB) error{
|
|
||||||
s.membershipStatements.prepare,
|
|
||||||
} {
|
|
||||||
if err = prepare(db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -34,7 +34,6 @@ import (
|
||||||
// A Database is used to store room events and stream offsets.
|
// A Database is used to store room events and stream offsets.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
statements statements
|
|
||||||
events tables.Events
|
events tables.Events
|
||||||
eventJSON tables.EventJSON
|
eventJSON tables.EventJSON
|
||||||
eventTypes tables.EventTypes
|
eventTypes tables.EventTypes
|
||||||
|
|
@ -43,6 +42,7 @@ type Database struct {
|
||||||
transactions tables.Transactions
|
transactions tables.Transactions
|
||||||
prevEvents tables.PreviousEvents
|
prevEvents tables.PreviousEvents
|
||||||
invites tables.Invites
|
invites tables.Invites
|
||||||
|
membership tables.Membership
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -73,9 +73,7 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
// acquire the global mutex and never unlock it because it is waiting for a connection
|
// acquire the global mutex and never unlock it because it is waiting for a connection
|
||||||
// which it will never obtain.
|
// which it will never obtain.
|
||||||
d.db.SetMaxOpenConns(20)
|
d.db.SetMaxOpenConns(20)
|
||||||
if err = d.statements.prepare(d.db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
|
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -120,6 +118,10 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.membership, err = NewSqliteMembershipTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
d.Database = shared.Database{
|
d.Database = shared.Database{
|
||||||
DB: d.db,
|
DB: d.db,
|
||||||
EventsTable: d.events,
|
EventsTable: d.events,
|
||||||
|
|
@ -133,6 +135,7 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
PrevEventsTable: d.prevEvents,
|
PrevEventsTable: d.prevEvents,
|
||||||
RoomAliasesTable: roomAliases,
|
RoomAliasesTable: roomAliases,
|
||||||
InvitesTable: d.invites,
|
InvitesTable: d.invites,
|
||||||
|
MembershipTable: d.membership,
|
||||||
}
|
}
|
||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
@ -364,7 +367,7 @@ type membershipUpdater struct {
|
||||||
d *Database
|
d *Database
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
targetUserNID types.EventStateKeyNID
|
targetUserNID types.EventStateKeyNID
|
||||||
membership membershipState
|
membership tables.MembershipState
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) membershipUpdaterTxn(
|
func (d *Database) membershipUpdaterTxn(
|
||||||
|
|
@ -375,11 +378,11 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
targetLocal bool,
|
targetLocal bool,
|
||||||
) (types.MembershipUpdater, error) {
|
) (types.MembershipUpdater, error) {
|
||||||
|
|
||||||
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -392,17 +395,17 @@ func (d *Database) membershipUpdaterTxn(
|
||||||
|
|
||||||
// IsInvite implements types.MembershipUpdater
|
// IsInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsInvite() bool {
|
func (u *membershipUpdater) IsInvite() bool {
|
||||||
return u.membership == membershipStateInvite
|
return u.membership == tables.MembershipStateInvite
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsJoin implements types.MembershipUpdater
|
// IsJoin implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsJoin() bool {
|
func (u *membershipUpdater) IsJoin() bool {
|
||||||
return u.membership == membershipStateJoin
|
return u.membership == tables.MembershipStateJoin
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLeave implements types.MembershipUpdater
|
// IsLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) IsLeave() bool {
|
func (u *membershipUpdater) IsLeave() bool {
|
||||||
return u.membership == membershipStateLeaveOrBan
|
return u.membership == tables.MembershipStateLeaveOrBan
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetToInvite implements types.MembershipUpdater
|
// SetToInvite implements types.MembershipUpdater
|
||||||
|
|
@ -418,9 +421,9 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if u.membership != membershipStateInvite {
|
if u.membership != tables.MembershipStateInvite {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -454,10 +457,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateJoin || isUpdate {
|
if u.membership != tables.MembershipStateJoin || isUpdate {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
membershipStateJoin, nIDs[eventID],
|
tables.MembershipStateJoin, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -488,10 +491,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inv
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateLeaveOrBan {
|
if u.membership != tables.MembershipStateLeaveOrBan {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.membership.UpdateMembership(
|
||||||
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
membershipStateLeaveOrBan, nIDs[eventID],
|
tables.MembershipStateLeaveOrBan, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -512,8 +515,8 @@ func (d *Database) GetMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
membershipEventNID, _, err =
|
membershipEventNID, _, err =
|
||||||
d.statements.selectMembershipFromRoomAndTarget(
|
d.membership.SelectMembershipFromRoomAndTarget(
|
||||||
ctx, txn, roomNID, requestSenderUserNID,
|
ctx, roomNID, requestSenderUserNID,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// The user has never been a member of that room
|
// The user has never been a member of that room
|
||||||
|
|
@ -533,18 +536,14 @@ func (d *Database) GetMembership(
|
||||||
func (d *Database) GetMembershipEventNIDsForRoom(
|
func (d *Database) GetMembershipEventNIDsForRoom(
|
||||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
|
if joinOnly {
|
||||||
if joinOnly {
|
eventNIDs, err = d.membership.SelectMembershipsFromRoomAndMembership(
|
||||||
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
|
ctx, roomNID, tables.MembershipStateJoin, localOnly,
|
||||||
ctx, txn, roomNID, membershipStateJoin, localOnly,
|
)
|
||||||
)
|
return
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
|
|
||||||
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
|
return d.membership.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type transaction struct {
|
type transaction struct {
|
||||||
|
|
|
||||||
|
|
@ -103,3 +103,20 @@ type Invites interface {
|
||||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||||
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error)
|
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MembershipState int64
|
||||||
|
|
||||||
|
const (
|
||||||
|
MembershipStateLeaveOrBan MembershipState = 1
|
||||||
|
MembershipStateInvite MembershipState = 2
|
||||||
|
MembershipStateJoin MembershipState = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
type Membership interface {
|
||||||
|
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
|
||||||
|
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
|
||||||
|
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error)
|
||||||
|
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue