Apply requested changes on database management

This commit is contained in:
Brendan Abolivier 2017-07-17 13:06:27 +01:00
parent 0bf3a016f9
commit 3ecf3b49d1
No known key found for this signature in database
GPG key ID: 8EF1500759F70623
3 changed files with 62 additions and 32 deletions

View file

@ -16,6 +16,8 @@ package accounts
import (
"database/sql"
"github.com/lib/pq"
)
const membershipSchema = `
@ -31,6 +33,9 @@ CREATE TABLE IF NOT EXISTS memberships (
-- A user can only be member of a room once
PRIMARY KEY (localpart, room_id)
);
-- Use index to process deletion by ID more efficiently
CREATE UNIQUE INDEX IF NOT EXISTS membership_event_id ON memberships(event_id);
`
const insertMembershipSQL = "" +
@ -48,11 +53,11 @@ const selectMembershipsByLocalpartSQL = "" +
const deleteMembershipSQL = "" +
"DELETE FROM memberships WHERE localpart = $1 AND room_id = $2"
const deleteMembershipByEventIDSQL = "" +
"DELETE FROM memberships WHERE event_id = $1"
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM memberships WHERE event_id = ANY($1)"
type membershipStatements struct {
deleteMembershipByEventIDStmt *sql.Stmt
deleteMembershipsByEventIDsStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipByEventIDStmt *sql.Stmt
@ -65,7 +70,7 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if err != nil {
return
}
if s.deleteMembershipByEventIDStmt, err = db.Prepare(deleteMembershipByEventIDSQL); err != nil {
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
return
}
if s.deleteMembershipStmt, err = db.Prepare(deleteMembershipSQL); err != nil {
@ -86,18 +91,18 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string) (err error) {
_, err = s.insertMembershipStmt.Exec(localpart, roomID, eventID)
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembership(localpart string, roomID string) (err error) {
_, err = s.deleteMembershipStmt.Exec(localpart, roomID)
func (s *membershipStatements) deleteMembership(localpart string, roomID string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.deleteMembershipStmt).Exec(localpart, roomID)
return
}
func (s *membershipStatements) deleteMembershipByEventID(eventID string) (err error) {
_, err = s.deleteMembershipByEventIDStmt.Exec(eventID)
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs))
return
}

View file

@ -118,22 +118,39 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
// room. It also stores the ID of the `join` membership event.
// If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string, eventID string) error {
return d.memberships.insertMembership(localpart, roomID, eventID)
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
}
// RemoveMembership removes the membership linking the user matching a given
// localpart and the room matching a given room ID.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) RemoveMembership(localpart string, roomID string) error {
return d.memberships.deleteMembership(localpart, roomID)
func (d *Database) RemoveMembership(localpart string, roomID string, txn *sql.Tx) error {
return d.memberships.deleteMembership(localpart, roomID, txn)
}
// RemoveMembershipByEventID removes the membership of which the `join` membership
// event ID matches a given event ID
// RemoveMembershipsByEventIDs removes the memberships of which the `join` membership
// event ID is included in a given array of events IDs
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) RemoveMembershipByEventID(eventID string) error {
return d.memberships.deleteMembershipByEventID(eventID)
func (d *Database) RemoveMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
}
// StartTransaction begins a new SQL transaction and returns it
// If there was an error during the transaction initialisation, returns it
func (d *Database) StartTransaction() (*sql.Tx, error) {
return d.db.Begin()
}
// EndTransation is called at the end of a transaction started with StartTransaction
// If called with an error, the transaction will rollback, if called with nil the
// transaction will commit
// If there was an error during either rollback or commit, returns it
func (d *Database) EndTransation(txn *sql.Tx, err error) error {
if err != nil {
return txn.Rollback()
}
return txn.Commit()
}
// GetMembershipByEventID returns the membership (as a user localpart and a room ID)

View file

@ -15,6 +15,7 @@
package consumers
import (
"database/sql"
"encoding/json"
log "github.com/Sirupsen/logrus"
@ -94,16 +95,26 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return err
}
txn, err := s.db.StartTransaction()
if err != nil {
return err
}
if err := s.db.RemoveMembershipsByEventIDs(output.NewRoomEvent.RemovesStateEventIDs, txn); err != nil {
if e := s.db.EndTransation(txn, err); e != nil {
return e
}
return err
}
for _, event := range events {
if err := s.updateMembership(event); err != nil {
if err := s.updateMembership(event, txn); err != nil {
return err
}
}
for _, id := range output.NewRoomEvent.RemovesStateEventIDs {
if err := s.db.RemoveMembershipByEventID(id); err != nil {
return err
}
if err := s.db.EndTransation(txn, nil); err != nil {
return err
}
return nil
@ -157,7 +168,7 @@ func (s *OutputRoomEvent) lookupStateEvents(
return result, nil
}
func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event) error {
func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
@ -176,14 +187,11 @@ func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event) error {
return err
}
switch membership {
case "join":
if err := s.db.SaveMembership(localpart, roomID, eventID); err != nil {
return err
}
case "leave":
case "ban":
if err := s.db.RemoveMembership(localpart, roomID); err != nil {
if membership == "join" {
if err := s.db.SaveMembership(localpart, roomID, eventID, txn); err != nil {
if e := s.db.EndTransation(txn, err); e != nil {
return e
}
return err
}
}