mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 01:13:10 -06:00
Saving memberships
This commit is contained in:
parent
7d36ca03af
commit
b89a6b0fb6
|
|
@ -0,0 +1,86 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package accounts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const membershipSchema = `
|
||||||
|
-- Stores data about accounts profiles.
|
||||||
|
CREATE TABLE IF NOT EXISTS memberships (
|
||||||
|
-- The Matrix user ID localpart for the member
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- The room this user is a member of
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY (localpart, room_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- A user can only be member of a room once
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS localpart_id_idx ON memberships(localpart, room_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertMembershipSQL = "" +
|
||||||
|
"INSERT INTO memberships(localpart, room_id) VALUES ($1, $2)"
|
||||||
|
|
||||||
|
const selectMembershipSQL = "" +
|
||||||
|
"SELECT * from memberships WHERE localpart = $1 AND room_id = $2"
|
||||||
|
|
||||||
|
const selectMembershipsByLocalpartSQL = "" +
|
||||||
|
"SELECT room_id FROM memberships WHERE localpart = $1"
|
||||||
|
|
||||||
|
const deleteMembershipSQL = "" +
|
||||||
|
"DELETE FROM memberships WHERE localpart = $1 AND room_id = $2"
|
||||||
|
|
||||||
|
type membershipStatements struct {
|
||||||
|
deleteMembershipStmt *sql.Stmt
|
||||||
|
insertMembershipStmt *sql.Stmt
|
||||||
|
selectMembershipsByLocalpartStmt *sql.Stmt
|
||||||
|
selectMembershipStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(membershipSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteMembershipStmt, err = db.Prepare(deleteMembershipSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) insertMembership(localpart string, roomID string) (err error) {
|
||||||
|
fmt.Printf("Inserting membership for user %s and room %s\n", localpart, roomID)
|
||||||
|
_, err = s.insertMembershipStmt.Exec(localpart, roomID)
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) deleteMembership(localpart string, roomID string) (err error) {
|
||||||
|
_, err = s.deleteMembershipStmt.Exec(localpart, roomID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
|
|
@ -27,8 +28,10 @@ import (
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
partitions common.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
|
memberships membershipStatements
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
|
|
@ -38,6 +41,10 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
if db, err = sql.Open("postgres", dataSourceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
partitions := common.PartitionOffsetStatements{}
|
||||||
|
if err = partitions.Prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
a := accountsStatements{}
|
a := accountsStatements{}
|
||||||
if err = a.prepare(db, serverName); err != nil {
|
if err = a.prepare(db, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -46,7 +53,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
if err = p.prepare(db); err != nil {
|
if err = p.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, a, p}, nil
|
m := membershipStatements{}
|
||||||
|
if err = m.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, partitions, a, p, m}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
|
@ -93,6 +104,30 @@ func (d *Database) CreateAccount(localpart, plaintextPassword string) (*authtype
|
||||||
return d.accounts.insertAccount(localpart, hash)
|
return d.accounts.insertAccount(localpart, hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PartitionOffsets implements common.PartitionStorer
|
||||||
|
func (d *Database) PartitionOffsets(topic string) ([]common.PartitionOffset, error) {
|
||||||
|
return d.partitions.SelectPartitionOffsets(topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPartitionOffset implements common.PartitionStorer
|
||||||
|
func (d *Database) SetPartitionOffset(topic string, partition int32, offset int64) error {
|
||||||
|
return d.partitions.UpsertPartitionOffset(topic, partition, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveMembership saves the user matching a given localpart as a member of a given
|
||||||
|
// room. If a membership already exists between the user and the room, or of the
|
||||||
|
// insert fails, returns the SQL error
|
||||||
|
func (d *Database) SaveMembership(localpart string, roomID string) error {
|
||||||
|
return d.memberships.insertMembership(localpart, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveMembership removes the membership of the user mathing a given localpart
|
||||||
|
// from a given room.
|
||||||
|
// If the removal fails, or if there is no membership to remove, returns an error
|
||||||
|
func (d *Database) RemoveMembership(localpart string, roomID string) error {
|
||||||
|
return d.memberships.deleteMembership(localpart, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
func hashPassword(plaintext string) (hash string, err error) {
|
func hashPassword(plaintext string) (hash string, err error) {
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
|
||||||
return string(hashBytes), err
|
return string(hashBytes), err
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package consumers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
"github.com/matrix-org/dendrite/common/config"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
sarama "gopkg.in/Shopify/sarama.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OutputRoomEvent consumes events that originated in the room server.
|
||||||
|
type OutputRoomEvent struct {
|
||||||
|
roomServerConsumer *common.ContinualConsumer
|
||||||
|
db *accounts.Database
|
||||||
|
query api.RoomserverQueryAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOutputRoomEvent creates a new OutputRoomEvent consumer. Call Start() to begin consuming from room servers.
|
||||||
|
func NewOutputRoomEvent(cfg *config.Dendrite, store *accounts.Database) (*OutputRoomEvent, error) {
|
||||||
|
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomServerURL := cfg.RoomServerURL()
|
||||||
|
|
||||||
|
consumer := common.ContinualConsumer{
|
||||||
|
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
|
||||||
|
Consumer: kafkaConsumer,
|
||||||
|
PartitionStore: store,
|
||||||
|
}
|
||||||
|
s := &OutputRoomEvent{
|
||||||
|
roomServerConsumer: &consumer,
|
||||||
|
db: store,
|
||||||
|
query: api.NewRoomserverQueryAPIHTTP(roomServerURL, nil),
|
||||||
|
}
|
||||||
|
consumer.ProcessMessage = s.onMessage
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start consuming from room servers
|
||||||
|
func (s *OutputRoomEvent) Start() error {
|
||||||
|
return s.roomServerConsumer.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// onMessage is called when the sync server receives a new event from the room server output log.
|
||||||
|
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||||
|
// sync stream position may race and be incorrectly calculated.
|
||||||
|
func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
|
// Parse out the event JSON
|
||||||
|
var output api.OutputRoomEvent
|
||||||
|
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
||||||
|
// If the message was invalid, log it and move on to the next message in the stream
|
||||||
|
log.WithError(err).Errorf("roomserver output log: message parse failure")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(output.Event, false)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Errorf("roomserver output log: event parse failure")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"event_id": ev.EventID(),
|
||||||
|
"room_id": ev.RoomID(),
|
||||||
|
"type": ev.Type(),
|
||||||
|
}).Info("received event from roomserver")
|
||||||
|
|
||||||
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
|
localpart := getLocalPart(*ev.StateKey())
|
||||||
|
roomID := ev.RoomID()
|
||||||
|
if err := s.db.SaveMembership(localpart, roomID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLocalPart(userID string) string {
|
||||||
|
if !strings.HasPrefix(userID, "@") {
|
||||||
|
panic(fmt.Errorf("Invalid user ID"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the part before ":"
|
||||||
|
username := strings.Split(userID, ":")[0]
|
||||||
|
// Return the part after the "@"
|
||||||
|
return strings.Split(username, "@")[1]
|
||||||
|
}
|
||||||
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/consumers"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
"github.com/matrix-org/dendrite/clientapi/routing"
|
"github.com/matrix-org/dendrite/clientapi/routing"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
@ -86,6 +87,14 @@ func main() {
|
||||||
KeyDatabase: keyDB,
|
KeyDatabase: keyDB,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
consumer, err := consumers.NewOutputRoomEvent(cfg, accountDB)
|
||||||
|
if err != nil {
|
||||||
|
log.Panicf("startup: failed to create room server consumer: %s", err)
|
||||||
|
}
|
||||||
|
if err = consumer.Start(); err != nil {
|
||||||
|
log.Panicf("startup: failed to start room server consumer")
|
||||||
|
}
|
||||||
|
|
||||||
log.Info("Starting client API server on ", cfg.Listen.ClientAPI)
|
log.Info("Starting client API server on ", cfg.Listen.ClientAPI)
|
||||||
routing.Setup(
|
routing.Setup(
|
||||||
http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer,
|
http.DefaultServeMux, http.DefaultClient, *cfg, roomserverProducer,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue