Moved memberships update to the database package

This commit is contained in:
Brendan Abolivier 2017-07-17 17:44:24 +01:00
parent d81fa3d00d
commit 29deb18f9f
No known key found for this signature in database
GPG key ID: 8EF1500759F70623
2 changed files with 69 additions and 73 deletions

View file

@ -32,6 +32,7 @@ type Database struct {
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
serverName gomatrixserverlib.ServerName
}
// NewDatabase creates a new accounts and profiles database
@ -57,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = m.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m}, nil
return &Database{db, partitions, a, p, m, serverName}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
@ -122,36 +123,81 @@ func (d *Database) SaveMembership(localpart string, roomID string, eventID strin
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
}
// RemoveMembershipsByEventIDs removes the memberships of which the `join` membership
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
// event ID is included in a given array of events IDs
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) RemoveMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
}
// StartTransaction begins a new SQL transaction and returns it
// If there was an error during the transaction initialisation, returns it
func (d *Database) StartTransaction() (*sql.Tx, error) {
return d.db.Begin()
}
// EndTransation is called at the end of a transaction started with StartTransaction
// If called with an error, the transaction will rollback, if called with nil the
// transaction will commit
// If there was an error during either rollback or commit, returns it
func (d *Database) EndTransation(txn *sql.Tx, err error) error {
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
txn, err := d.db.Begin()
if err != nil {
return txn.Rollback()
if e := txn.Rollback(); e != nil {
return e
}
return err
}
return txn.Commit()
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
if e := txn.Rollback(); e != nil {
return e
}
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(event, txn); err != nil {
if e := txn.Rollback(); e != nil {
return e
}
return err
}
}
if err := txn.Commit(); err != nil {
return err
}
return nil
}
// GetMembershipByEventID returns the membership (as a user localpart and a room ID)
// for which the `join` membership event ID matches a given event ID
// If no membership match this event ID, the localpart and room ID will be empty strings
// If an error happens during the retrieval, returns the SQL error
func (d *Database) GetMembershipByEventID(eventID string) (string, string, error) {
return d.memberships.selectMembershipByEventID(eventID)
// newMembership will save a new membership in the database if the given state
// event is a "join" membership event
// If the event isn't a "join" membership event, does nothing
// If an error occurred, returns it
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == "join" {
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
return err
}
}
}
return nil
}
func hashPassword(plaintext string) (hash string, err error) {

View file

@ -15,7 +15,6 @@
package consumers
import (
"database/sql"
"encoding/json"
log "github.com/Sirupsen/logrus"
@ -95,25 +94,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
return err
}
txn, err := s.db.StartTransaction()
if err != nil {
return err
}
if err := s.db.RemoveMembershipsByEventIDs(output.NewRoomEvent.RemovesStateEventIDs, txn); err != nil {
if e := s.db.EndTransation(txn, err); e != nil {
return e
}
return err
}
for _, event := range events {
if err := s.updateMembership(event, txn); err != nil {
return err
}
}
if err := s.db.EndTransation(txn, nil); err != nil {
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
return err
}
@ -158,34 +139,3 @@ func (s *OutputRoomEvent) lookupStateEvents(
return result, nil
}
func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// we only want state events from local users
if string(serverName) != s.serverName {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
if membership == "join" {
if err := s.db.SaveMembership(localpart, roomID, eventID, txn); err != nil {
if e := s.db.EndTransation(txn, err); e != nil {
return e
}
return err
}
}
}
return nil
}