mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-11 17:03:10 -06:00
Moved memberships update to the database package
This commit is contained in:
parent
d81fa3d00d
commit
29deb18f9f
|
|
@ -32,6 +32,7 @@ type Database struct {
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
memberships membershipStatements
|
memberships membershipStatements
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
|
|
@ -57,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
if err = m.prepare(db); err != nil {
|
if err = m.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, partitions, a, p, m}, nil
|
return &Database{db, partitions, a, p, m, serverName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
|
@ -122,36 +123,81 @@ func (d *Database) SaveMembership(localpart string, roomID string, eventID strin
|
||||||
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
|
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveMembershipsByEventIDs removes the memberships of which the `join` membership
|
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
|
||||||
// event ID is included in a given array of events IDs
|
// event ID is included in a given array of events IDs
|
||||||
// If the removal fails, or if there is no membership to remove, returns an error
|
// If the removal fails, or if there is no membership to remove, returns an error
|
||||||
func (d *Database) RemoveMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
|
||||||
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
|
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartTransaction begins a new SQL transaction and returns it
|
// UpdateMemberships adds the "join" membership events included in a given state
|
||||||
// If there was an error during the transaction initialisation, returns it
|
// events array, and removes those which ID is included in a given array of events
|
||||||
func (d *Database) StartTransaction() (*sql.Tx, error) {
|
// IDs. All of the process is run in a transaction, which commits only once/if every
|
||||||
return d.db.Begin()
|
// insertion and deletion has been successfully processed.
|
||||||
}
|
// Returns a SQL error if there was an issue with any part of the process
|
||||||
|
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
|
||||||
// EndTransation is called at the end of a transaction started with StartTransaction
|
txn, err := d.db.Begin()
|
||||||
// If called with an error, the transaction will rollback, if called with nil the
|
|
||||||
// transaction will commit
|
|
||||||
// If there was an error during either rollback or commit, returns it
|
|
||||||
func (d *Database) EndTransation(txn *sql.Tx, err error) error {
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return txn.Rollback()
|
if e := txn.Rollback(); e != nil {
|
||||||
|
return e
|
||||||
}
|
}
|
||||||
return txn.Commit()
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
|
||||||
|
if e := txn.Rollback(); e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range eventsToAdd {
|
||||||
|
if err := d.newMembership(event, txn); err != nil {
|
||||||
|
if e := txn.Rollback(); e != nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := txn.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembershipByEventID returns the membership (as a user localpart and a room ID)
|
// newMembership will save a new membership in the database if the given state
|
||||||
// for which the `join` membership event ID matches a given event ID
|
// event is a "join" membership event
|
||||||
// If no membership match this event ID, the localpart and room ID will be empty strings
|
// If the event isn't a "join" membership event, does nothing
|
||||||
// If an error happens during the retrieval, returns the SQL error
|
// If an error occurred, returns it
|
||||||
func (d *Database) GetMembershipByEventID(eventID string) (string, string, error) {
|
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
||||||
return d.memberships.selectMembershipByEventID(eventID)
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
|
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We only want state events from local users
|
||||||
|
if string(serverName) != string(d.serverName) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventID := ev.EventID()
|
||||||
|
roomID := ev.RoomID()
|
||||||
|
membership, err := ev.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only "join" membership events can be considered as new memberships
|
||||||
|
if membership == "join" {
|
||||||
|
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@
|
||||||
package consumers
|
package consumers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
|
@ -95,25 +94,7 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
txn, err := s.db.StartTransaction()
|
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.db.RemoveMembershipsByEventIDs(output.NewRoomEvent.RemovesStateEventIDs, txn); err != nil {
|
|
||||||
if e := s.db.EndTransation(txn, err); e != nil {
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
if err := s.updateMembership(event, txn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.db.EndTransation(txn, nil); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -158,34 +139,3 @@ func (s *OutputRoomEvent) lookupStateEvents(
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEvent) updateMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
|
|
||||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
|
||||||
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// we only want state events from local users
|
|
||||||
if string(serverName) != s.serverName {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
eventID := ev.EventID()
|
|
||||||
roomID := ev.RoomID()
|
|
||||||
membership, err := ev.Membership()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if membership == "join" {
|
|
||||||
if err := s.db.SaveMembership(localpart, roomID, eventID, txn); err != nil {
|
|
||||||
if e := s.db.EndTransation(txn, err); e != nil {
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue