mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-10 16:33:11 -06:00
Moved memberships update to the database package
This commit is contained in:
parent
d81fa3d00d
commit
29deb18f9f
|
|
@ -32,6 +32,7 @@ type Database struct {
|
|||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
memberships membershipStatements
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
|
|
@ -57,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
|||
if err = m.prepare(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{db, partitions, a, p, m}, nil
|
||||
return &Database{db, partitions, a, p, m, serverName}, nil
|
||||
}
|
||||
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
|
|
@ -122,36 +123,81 @@ func (d *Database) SaveMembership(localpart string, roomID string, eventID strin
|
|||
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
|
||||
}
|
||||
|
||||
// RemoveMembershipsByEventIDs removes the memberships of which the `join` membership
|
||||
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
|
||||
// event ID is included in a given array of events IDs
|
||||
// If the removal fails, or if there is no membership to remove, returns an error
|
||||
func (d *Database) RemoveMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
||||
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
||||
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
|
||||
}
|
||||
|
||||
// StartTransaction begins a new SQL transaction and returns it
|
||||
// If there was an error during the transaction initialisation, returns it
|
||||
func (d *Database) StartTransaction() (*sql.Tx, error) {
|
||||
return d.db.Begin()
|
||||
}
|
||||
|
||||
// EndTransation is called at the end of a transaction started with StartTransaction
|
||||
// If called with an error, the transaction will rollback, if called with nil the
|
||||
// transaction will commit
|
||||
// If there was an error during either rollback or commit, returns it
|
||||
func (d *Database) EndTransation(txn *sql.Tx, err error) error {
|
||||
// UpdateMemberships adds the "join" membership events included in a given state
|
||||
// events array, and removes those which ID is included in a given array of events
|
||||
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||
// insertion and deletion has been successfully processed.
|
||||
// Returns a SQL error if there was an issue with any part of the process
|
||||
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
|
||||
txn, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return txn.Rollback()
|
||||
if e := txn.Rollback(); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
return txn.Commit()
|
||||
|
||||
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
|
||||
if e := txn.Rollback(); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range eventsToAdd {
|
||||
if err := d.newMembership(event, txn); err != nil {
|
||||
if e := txn.Rollback(); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := txn.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMembershipByEventID returns the membership (as a user localpart and a room ID)
|
||||
// for which the `join` membership event ID matches a given event ID
|
||||
// If no membership match this event ID, the localpart and room ID will be empty strings
|
||||
// If an error happens during the retrieval, returns the SQL error
|
||||
func (d *Database) GetMembershipByEventID(eventID string) (string, string, error) {
|
||||
return d.memberships.selectMembershipByEventID(eventID)
|
||||
// newMembership will save a new membership in the database if the given state
|
||||
// event is a "join" membership event
|
||||
// If the event isn't a "join" membership event, does nothing
|
||||
// If an error occurred, returns it
|
||||
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only want state events from local users
|
||||
if string(serverName) != string(d.serverName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
eventID := ev.EventID()
|
||||
roomID := ev.RoomID()
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only "join" membership events can be considered as new memberships
|
||||
if membership == "join" {
|
||||
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashPassword(plaintext string) (hash string, err error) {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
package consumers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
|
|
@ -95,25 +94,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
|||
return err
|
||||
}
|
||||
|
||||
txn, err := s.db.StartTransaction()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.RemoveMembershipsByEventIDs(output.NewRoomEvent.RemovesStateEventIDs, txn); err != nil {
|
||||
if e := s.db.EndTransation(txn, err); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
if err := s.updateMembership(event, txn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.EndTransation(txn, nil); err != nil {
|
||||
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -158,34 +139,3 @@ func (s *OutputRoomEvent) lookupStateEvents(
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we only want state events from local users
|
||||
if string(serverName) != s.serverName {
|
||||
return nil
|
||||
}
|
||||
|
||||
eventID := ev.EventID()
|
||||
roomID := ev.RoomID()
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if membership == "join" {
|
||||
if err := s.db.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
||||
if e := s.db.EndTransation(txn, err); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue