Saving memberships

This commit is contained in:
Brendan Abolivier 2017-07-12 18:46:22 +01:00
parent 7d36ca03af
commit b89a6b0fb6
No known key found for this signature in database
GPG key ID: 8EF1500759F70623
4 changed files with 243 additions and 4 deletions

View file

@ -0,0 +1,86 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package accounts
import (
"database/sql"
"fmt"
)
const membershipSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS memberships (
-- The Matrix user ID localpart for the member
localpart TEXT NOT NULL,
-- The room this user is a member of
room_id TEXT NOT NULL,
PRIMARY KEY (localpart, room_id)
);
-- A user can only be member of a room once
CREATE UNIQUE INDEX IF NOT EXISTS localpart_id_idx ON memberships(localpart, room_id);
`
const insertMembershipSQL = "" +
"INSERT INTO memberships(localpart, room_id) VALUES ($1, $2)"
const selectMembershipSQL = "" +
"SELECT * from memberships WHERE localpart = $1 AND room_id = $2"
const selectMembershipsByLocalpartSQL = "" +
"SELECT room_id FROM memberships WHERE localpart = $1"
const deleteMembershipSQL = "" +
"DELETE FROM memberships WHERE localpart = $1 AND room_id = $2"
type membershipStatements struct {
deleteMembershipStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
selectMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(membershipSchema)
if err != nil {
return
}
if s.deleteMembershipStmt, err = db.Prepare(deleteMembershipSQL); err != nil {
return
}
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
return
}
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
return
}
return
}
func (s *membershipStatements) insertMembership(localpart string, roomID string) (err error) {
fmt.Printf("Inserting membership for user %s and room %s\n", localpart, roomID)
_, err = s.insertMembershipStmt.Exec(localpart, roomID)
fmt.Println(err)
return
}
func (s *membershipStatements) deleteMembership(localpart string, roomID string) (err error) {
_, err = s.deleteMembershipStmt.Exec(localpart, roomID)
return
}

View file

@ -18,6 +18,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
@ -26,9 +27,11 @@ import (
// Database represents an account database
type Database struct {
db *sql.DB
accounts accountsStatements
profiles profilesStatements
db *sql.DB
partitions common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
memberships membershipStatements
}
// NewDatabase creates a new accounts and profiles database
@ -38,6 +41,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
return nil, err
@ -46,7 +53,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = p.prepare(db); err != nil {
return nil, err
}
return &Database{db, a, p}, nil
m := membershipStatements{}
if err = m.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
@ -93,6 +104,30 @@ func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtype
return d.accounts.insertAccount(localpart, hash)
}
// PartitionOffsets implements common.PartitionStorer
func (d *Database) PartitionOffsets(topic string) ([]common.PartitionOffset, error) {
return d.partitions.SelectPartitionOffsets(topic)
}
// SetPartitionOffset implements common.PartitionStorer
func (d *Database) SetPartitionOffset(topic string, partition int32, offset int64) error {
return d.partitions.UpsertPartitionOffset(topic, partition, offset)
}
// SaveMembership saves the user matching a given localpart as a member of a given
// room. If a membership already exists between the user and the room, or of the
// insert fails, returns the SQL error
func (d *Database) SaveMembership(localpart string, roomID string) error {
return d.memberships.insertMembership(localpart, roomID)
}
// RemoveMembership removes the membership of the user mathing a given localpart
// from a given room.
// If the removal fails, or if there is no membership to remove, returns an error
func (d *Database) RemoveMembership(localpart string, roomID string) error {
return d.memberships.deleteMembership(localpart, roomID)
}
func hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
return string(hashBytes), err

View file

@ -0,0 +1,109 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"encoding/json"
"fmt"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
sarama "gopkg.in/Shopify/sarama.v1"
)
// OutputRoomEvent consumes events that originated in the room server.
type OutputRoomEvent struct {
roomServerConsumer *common.ContinualConsumer
db *accounts.Database
query api.RoomserverQueryAPI
}
// NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers.
func NewOutputRoomEvent(cfg *config.Dendrite, store *accounts.Database) (*OutputRoomEvent, error) {
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil {
return nil, err
}
roomServerURL := cfg.RoomServerURL()
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputRoomEvent{
roomServerConsumer: &consumer,
db: store,
query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil),
}
consumer.ProcessMessage = s.onMessage
return s, nil
}
// Start consuming from room servers
func (s *OutputRoomEvent) Start() error {
return s.roomServerConsumer.Start()
}
// onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
// Parse out the event JSON
var output api.OutputRoomEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return nil
}
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(output.Event, false)
if err != nil {
log.WithError(err).Errorf("roomserver output log: event parse failure")
return nil
}
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"room_id": ev.RoomID(),
"type": ev.Type(),
}).Info("received event from roomserver")
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
localpart := getLocalPart(*ev.StateKey())
roomID := ev.RoomID()
if err := s.db.SaveMembership(localpart, roomID); err != nil {
return err
}
}
return nil
}
func getLocalPart(userID string) string {
if !strings.HasPrefix(userID, "@") {
panic(fmt.Errorf("Invalid user ID"))
}
// Get the part before ":"
username := strings.Split(userID, ":")[0]
// Return the part after the "@"
return strings.Split(username, "@")[1]
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/clientapi/consumers"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/clientapi/routing"
"github.com/matrix-org/dendrite/common"
@ -86,6 +87,14 @@ func main() {
KeyDatabase: keyDB,
}
consumer, err := consumers.NewOutputRoomEvent(cfg, accountDB)
if err != nil {
log.Panicf("startup: failed to create room server consumer: %s", err)
}
if err = consumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer")
}
log.Info("Starting client API server on ", cfg.Listen.ClientAPI)
routing.Setup(
http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer,