Keep track of membership in Client API (#159)

* Saving memberships

* Removed unused index

* Removed useless log

* Fixed membership not being saved on the right conditions + added membership removal

* Updated outdated comment

* Use server lib method + check server name + use new roomserver API

* Better handling of events from the room server

* Fixed membership removal

* Corrected indentation

* Fix tests (hopefully)

* Replace broken kafka mirror

* Apply requested changes on database management

* Remove useless check and function

* Moved memberships update to the database package

* Use new common function

* Remove useless function
This commit is contained in:
Brendan Abolivier 2017-07-17 18:10:56 +01:00 committed by Mark Haines
parent b06d1124f7
commit d9b8e5de45
8 changed files with 345 additions and 22 deletions

View file

@ -0,0 +1,85 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"database/sql"
"github.com/lib/pq"
)
const membershipSchema = `
-- Stores data about users memberships to rooms.
CREATE TABLE IF NOT EXISTS memberships (
-- The Matrix user ID localpart for the member
localpart TEXT NOT NULL,
-- The room this user is a member of
room_id TEXT NOT NULL,
-- The ID of the join membership event
event_id TEXT NOT NULL,
-- A user can only be member of a room once
PRIMARY KEY (localpart, room_id)
);
-- Use index to process deletion by ID more efficiently
CREATE UNIQUE INDEX IF NOT EXISTS membership_event_id ON memberships(event_id);
`
const insertMembershipSQL = "" +
"INSERT INTO memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)"
const selectMembershipSQL = "" +
"SELECT * from memberships WHERE localpart = $1 AND room_id = $2"
const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id FROM memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM memberships WHERE event_id = ANY($1)"
type membershipStatements struct {
deleteMembershipsByEventIDsStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipByEventIDStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return
}
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
return
}
func (s *membershipStatements) insertMembership(localpart string, roomID string, eventID string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.insertMembershipStmt).Exec(localpart, roomID, eventID)
return
}
func (s *membershipStatements) deleteMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) (err error) {
_, err = txn.Stmt(s.deleteMembershipsByEventIDsStmt).Exec(pq.StringArray(eventIDs))
return
}

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
// Import the postgres database driver. // Import the postgres database driver.
@ -27,8 +28,11 @@ import (
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB db *sql.DB
partitions common.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
memberships membershipStatements
serverName gomatrixserverlib.ServerName
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -38,6 +42,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if db, err = sql.Open("postgres", dataSourceName); err != nil { if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err return nil, err
} }
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db); err != nil {
return nil, err
}
a := accountsStatements{} a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil { if err = a.prepare(db, serverName); err != nil {
return nil, err return nil, err
@ -46,7 +54,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = p.prepare(db); err != nil { if err = p.prepare(db); err != nil {
return nil, err return nil, err
} }
return &Database{db, a, p}, nil m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, serverName}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -93,6 +105,85 @@ func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtype
return d.accounts.insertAccount(localpart, hash) return d.accounts.insertAccount(localpart, hash)
} }
// PartitionOffsets implements common.PartitionStorer
func (d *Database) PartitionOffsets(topic string) ([]common.PartitionOffset, error) {
return d.partitions.SelectPartitionOffsets(topic)
}
// SetPartitionOffset implements common.PartitionStorer
func (d *Database) SetPartitionOffset(topic string, partition int32, offset int64) error {
return d.partitions.UpsertPartitionOffset(topic, partition, offset)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. It also stores the ID of the `join` membership event.
// If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string, eventID string, txn *sql.Tx) error {
return d.memberships.insertMembership(localpart, roomID, eventID, txn)
}
// removeMembershipsByEventIDs removes the memberships of which the `join` membership
// event ID is included in a given array of events IDs
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) removeMembershipsByEventIDs(eventIDs []string, txn *sql.Tx) error {
return d.memberships.deleteMembershipsByEventIDs(eventIDs, txn)
}
// UpdateMemberships adds the "join" membership events included in a given state
// events array, and removes those which ID is included in a given array of events
// IDs. All of the process is run in a transaction, which commits only once/if every
// insertion and deletion has been successfully processed.
// Returns a SQL error if there was an issue with any part of the process
func (d *Database) UpdateMemberships(eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.removeMembershipsByEventIDs(idsToRemove, txn); err != nil {
return err
}
for _, event := range eventsToAdd {
if err := d.newMembership(event, txn); err != nil {
return err
}
}
return nil
})
}
// newMembership will save a new membership in the database if the given state
// event is a "join" membership event
// If the event isn't a "join" membership event, does nothing
// If an error occurred, returns it
func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error {
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return err
}
// We only want state events from local users
if string(serverName) != string(d.serverName) {
return nil
}
eventID := ev.EventID()
roomID := ev.RoomID()
membership, err := ev.Membership()
if err != nil {
return err
}
// Only "join" membership events can be considered as new memberships
if membership == "join" {
if err := d.SaveMembership(localpart, roomID, eventID, txn); err != nil {
return err
}
}
}
return nil
}
func hashPassword(plaintext string) (hash string, err error) { func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err return string(hashBytes), err

View file

@ -0,0 +1,141 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"encoding/json"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
sarama "gopkg.in/Shopify/sarama.v1"
)
// OutputRoomEvent consumes events that originated in the room server.
type OutputRoomEvent struct {
roomServerConsumer *common.ContinualConsumer
db *accounts.Database
query api.RoomserverQueryAPI
serverName string
}
// NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers.
func NewOutputRoomEvent(cfg *config.Dendrite, store *accounts.Database) (*OutputRoomEvent, error) {
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil {
return nil, err
}
roomServerURL := cfg.RoomServerURL()
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputRoomEvent{
roomServerConsumer: &consumer,
db: store,
query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil),
serverName: string(cfg.Matrix.ServerName),
}
consumer.ProcessMessage = s.onMessage
return s, nil
}
// Start consuming from room servers
func (s *OutputRoomEvent) Start() error {
return s.roomServerConsumer.Start()
}
// onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
// Parse out the event JSON
var output api.OutputEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return nil
}
if output.Type != api.OutputTypeNewRoomEvent {
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
return nil
}
ev := output.NewRoomEvent.Event
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"room_id": ev.RoomID(),
"type": ev.Type(),
}).Info("received event from roomserver")
events, err := s.lookupStateEvents(output.NewRoomEvent.AddsStateEventIDs, ev)
if err != nil {
return err
}
if err := s.db.UpdateMemberships(events, output.NewRoomEvent.RemovesStateEventIDs); err != nil {
return err
}
return nil
}
// lookupStateEvents looks up the state events that are added by a new event.
func (s *OutputRoomEvent) lookupStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event,
) ([]gomatrixserverlib.Event, error) {
// Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 {
return nil, nil
}
// Fast path if the only state event added is the event itself.
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
return []gomatrixserverlib.Event{event}, nil
}
result := []gomatrixserverlib.Event{}
missing := []string{}
for _, id := range addsStateEventIDs {
// Append the current event in the results if its ID is in the events list
if id == event.EventID() {
result = append(result, event)
} else {
// If the event isn't the current one, add it to the list of events
// to retrieve from the roomserver
missing = append(missing, id)
}
}
// Request the missing events from the roomserver
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
var eventResp api.QueryEventsByIDResponse
if err := s.query.QueryEventsByID(&eventReq, &eventResp); err != nil {
return nil, err
}
result = append(result, eventResp.Events...)
return result, nil
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/consumers"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/clientapi/routing" "github.com/matrix-org/dendrite/clientapi/routing"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -86,6 +87,14 @@ func main() {
KeyDatabase: keyDB, KeyDatabase: keyDB,
} }
consumer, err := consumers.NewOutputRoomEvent(cfg, accountDB)
if err != nil {
log.Panicf("startup: failed to create room server consumer: %s", err)
}
if err = consumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer")
}
log.Info("Starting client API server on ", cfg.Listen.ClientAPI) log.Info("Starting client API server on ", cfg.Listen.ClientAPI)
routing.Setup( routing.Setup(
http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer, http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer,

View file

@ -104,14 +104,17 @@ func startMediaAPI(suffix string, dynamicThumbnails bool) (*exec.Cmd, chan error
proxyCmd, proxyCmdChan := test.StartProxy(proxyAddr, cfg) proxyCmd, proxyCmdChan := test.StartProxy(proxyAddr, cfg)
cmd, cmdChan := test.StartServer( test.InitDatabase(
serverType,
serverArgs,
postgresDatabase, postgresDatabase,
postgresContainerName, postgresContainerName,
databases, databases,
) )
cmd, cmdChan := test.CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-"+serverType+"-server"),
serverArgs,
)
fmt.Printf("==TESTSERVER== STARTED %v -> %v : %v\n", proxyAddr, cfg.Listen.MediaAPI, dir) fmt.Printf("==TESTSERVER== STARTED %v -> %v : %v\n", proxyAddr, cfg.Listen.MediaAPI, dir)
return cmd, cmdChan, string(cfg.Listen.MediaAPI), proxyCmd, proxyCmdChan, proxyAddr, dir return cmd, cmdChan, string(cfg.Listen.MediaAPI), proxyCmd, proxyCmdChan, proxyAddr, dir
} }

View file

@ -147,9 +147,7 @@ func startSyncServer() (*exec.Cmd, chan error) {
testDatabaseName, testDatabaseName,
} }
cmd, cmdChan := test.StartServer( test.InitDatabase(
"sync-api",
serverArgs,
postgresDatabase, postgresDatabase,
postgresContainerName, postgresContainerName,
databases, databases,
@ -165,6 +163,11 @@ func startSyncServer() (*exec.Cmd, chan error) {
panic(err) panic(err)
} }
cmd, cmdChan := test.CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-sync-api-server"),
serverArgs,
)
return cmd, cmdChan return cmd, cmdChan
} }

View file

@ -65,12 +65,8 @@ func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan err
return cmd, cmdChan return cmd, cmdChan
} }
// StartServer creates the database and config file needed for the server to run and // InitDatabase creates the database and config file needed for the server to run
// then starts the server. The Cmd being executed is returned. A channel is also returned, func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
// which will have any termination errors sent down it, followed immediately by the channel being closed.
// If postgresContainerName is not an empty string, psql will be run from inside that container. If it is
// an empty string, psql will be assumed to be in PATH.
func StartServer(serverType string, serverArgs []string, postgresDatabase, postgresContainerName string, databases []string) (*exec.Cmd, chan error) {
if len(databases) > 0 { if len(databases) > 0 {
var dbCmd string var dbCmd string
var dbArgs []string var dbArgs []string
@ -89,11 +85,6 @@ func StartServer(serverType string, serverArgs []string, postgresDatabase, postg
} }
} }
} }
return CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "dendrite-"+serverType+"-server"),
serverArgs,
)
} }
// StartProxy creates a reverse proxy // StartProxy creates a reverse proxy

View file

@ -5,7 +5,7 @@ set -eu
# The mirror to download kafka from is picked from the list of mirrors at # The mirror to download kafka from is picked from the list of mirrors at
# https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz # https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
# TODO: Check the signature since we are downloading over HTTP. # TODO: Check the signature since we are downloading over HTTP.
MIRROR=http://mirror.ox.ac.uk/sites/rsync.apache.org/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz MIRROR=http://apache.mirror.anlx.net/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
# Only download the kafka if it isn't already downloaded. # Only download the kafka if it isn't already downloaded.
test -f kafka.tgz || wget $MIRROR -O kafka.tgz test -f kafka.tgz || wget $MIRROR -O kafka.tgz