Merge branch 'master' of github.com:matrix-org/dendrite into erikj/sync_since

Erik Johnston 2017-11-21 13:37:20 +00:00
commit 3c8a06b088
13 changed files with 683 additions and 109 deletions


@@ -5,9 +5,9 @@
 set -eu
 # The mirror to download kafka from is picked from the list of mirrors at
-# https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
+# https://www.apache.org/dyn/closer.cgi?path=/kafka/0.10.2.0/kafka_2.11-0.11.0.2.tgz
 # TODO: Check the signature since we are downloading over HTTP.
-MIRROR=http://apache.mirror.anlx.net/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz
+MIRROR=http://apache.mirror.anlx.net/kafka/0.11.0.2/kafka_2.11-0.11.0.2.tgz
 # Only download the kafka if it isn't already downloaded.
 test -f kafka.tgz || wget $MIRROR -O kafka.tgz
@@ -18,7 +18,7 @@ mkdir -p kafka && tar xzf kafka.tgz -C kafka --strip-components 1
 kafka/bin/zookeeper-server-start.sh -daemon kafka/config/zookeeper.properties
 # Enable topic deletion so that the integration tests can create a fresh topic
 # for each test run.
-echo "delete.topic.enable=true" >> kafka/config/server.properties
+echo -e "\n\ndelete.topic.enable=true" >> kafka/config/server.properties
 # Start the kafka server running in the background.
 # By default the kafka listens on localhost:9092
 kafka/bin/kafka-server-start.sh -daemon kafka/config/server.properties


@@ -92,11 +92,11 @@ func (s *filterStatements) insertFilter(
     // Check if filter already exists in the database
     err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
         localpart, filterJSON).Scan(&existingFilterID)
-    if err != nil {
+    if err != nil && err != sql.ErrNoRows {
         return "", err
     }
     // If it does, return the existing ID
-    if len(existingFilterID) != 0 {
+    if existingFilterID != "" {
         return existingFilterID, err
     }
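The new error check treats sql.ErrNoRows as "no matching filter" rather than a failure, and the emptiness test on the returned ID becomes a plain string comparison. A minimal self-contained sketch of the same find-or-insert pattern (the table and column names here are placeholders, not the dendrite schema):

package storage

import (
    "database/sql"
)

// findOrInsertFilter illustrates the pattern from the hunk above: a missing row
// (sql.ErrNoRows) is a normal outcome, any other error is returned to the caller.
func findOrInsertFilter(db *sql.DB, localpart, filterJSON string) (string, error) {
    var existingID string
    err := db.QueryRow(
        "SELECT filter_id FROM account_filters WHERE localpart = $1 AND filter = $2",
        localpart, filterJSON,
    ).Scan(&existingID)
    if err != nil && err != sql.ErrNoRows {
        return "", err // a real database error
    }
    if existingID != "" {
        return existingID, nil // reuse the existing filter ID
    }
    // No existing filter: insert one and return the newly assigned ID.
    err = db.QueryRow(
        "INSERT INTO account_filters (localpart, filter) VALUES ($1, $2) RETURNING filter_id",
        localpart, filterJSON,
    ).Scan(&existingID)
    return existingID, err
}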


@@ -26,16 +26,10 @@ import (
     "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
     "github.com/matrix-org/dendrite/clientapi/httputil"
     "github.com/matrix-org/dendrite/common/config"
+    "github.com/matrix-org/gomatrix"
     "github.com/matrix-org/util"
 )
-type turnServerResponse struct {
-    Username string `json:"username"`
-    Password string `json:"password"`
-    URIs []string `json:"uris"`
-    TTL int `json:"ttl"`
-}
 // RequestTurnServer implements:
 // GET /voip/turnServer
 func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.Dendrite) util.JSONResponse {
@@ -52,7 +46,7 @@ func RequestTurnServer(req *http.Request, device *authtypes.Device, cfg config.D
     // Duration checked at startup, err not possible
     duration, _ := time.ParseDuration(turnConfig.UserLifetime)
-    resp := turnServerResponse{
+    resp := gomatrix.RespTurnServer{
         URIs: turnConfig.URIs,
         TTL: int(duration.Seconds()),
     }
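The handler now returns gomatrix.RespTurnServer instead of a locally defined turnServerResponse, keeping the same JSON shape. A rough sketch of assembling the full payload, assuming RespTurnServer carries the four fields of the removed struct (username and password are presumably filled in elsewhere in the real handler):

package routing

import (
    "time"

    "github.com/matrix-org/gomatrix"
)

// buildTurnResponse is a hypothetical helper showing the payload shape now
// provided by gomatrix.RespTurnServer rather than a local struct.
func buildTurnResponse(username, password string, uris []string, lifetime string) (gomatrix.RespTurnServer, error) {
    duration, err := time.ParseDuration(lifetime)
    if err != nil {
        return gomatrix.RespTurnServer{}, err
    }
    return gomatrix.RespTurnServer{
        Username: username,
        Password: password,
        URIs:     uris,
        TTL:      int(duration.Seconds()),
    }, nil
}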


@@ -80,6 +80,7 @@ func main() {
     queryAPI := api.NewRoomserverQueryAPIHTTP(cfg.RoomServerURL(), nil)
     inputAPI := api.NewRoomserverInputAPIHTTP(cfg.RoomServerURL(), nil)
+    aliasAPI := api.NewRoomserverAliasAPIHTTP(cfg.RoomServerURL(), nil)
     roomserverProducer := producers.NewRoomserverProducer(inputAPI)
@@ -90,7 +91,7 @@ func main() {
     log.Info("Starting federation API server on ", cfg.Listen.FederationAPI)
     api := mux.NewRouter()
-    routing.Setup(api, *cfg, queryAPI, roomserverProducer, keyRing, federation, accountDB)
+    routing.Setup(api, *cfg, queryAPI, aliasAPI, roomserverProducer, keyRing, federation, accountDB)
     common.SetupHTTPAPI(http.DefaultServeMux, api)
     log.Fatal(http.ListenAndServe(string(cfg.Listen.FederationAPI), nil))


@@ -16,6 +16,7 @@ package main
 import (
     "context"
+    "database/sql"
     "flag"
     "net/http"
     "os"
@@ -199,7 +200,21 @@ func (m *monolith) setupFederation() {
 func (m *monolith) setupKafka() {
     if m.cfg.Kafka.UseNaffka {
-        naff, err := naffka.New(&naffka.MemoryDatabase{})
+        db, err := sql.Open("postgres", string(m.cfg.Database.Naffka))
+        if err != nil {
+            log.WithFields(log.Fields{
+                log.ErrorKey: err,
+            }).Panic("Failed to open naffka database")
+        }
+        naffkaDB, err := naffka.NewPostgresqlDatabase(db)
+        if err != nil {
+            log.WithFields(log.Fields{
+                log.ErrorKey: err,
+            }).Panic("Failed to setup naffka database")
+        }
+        naff, err := naffka.New(naffkaDB)
         if err != nil {
             log.WithFields(log.Fields{
                 log.ErrorKey: err,
@@ -333,7 +348,7 @@ func (m *monolith) setupAPIs() {
     ), m.syncAPIDB, m.deviceDB)
     federationapi_routing.Setup(
-        m.api, *m.cfg, m.queryAPI, m.roomServerProducer, m.keyRing, m.federation,
+        m.api, *m.cfg, m.queryAPI, m.aliasAPI, m.roomServerProducer, m.keyRing, m.federation,
         m.accountDB,
     )


@@ -148,6 +148,8 @@ type Dendrite struct {
         // The PublicRoomsAPI database stores information used to compute the public
         // room directory. It is only accessed by the PublicRoomsAPI server.
         PublicRoomsAPI DataSource `yaml:"public_rooms_api"`
+        // The Naffka database is used internally by the naffka library, if used.
+        Naffka DataSource `yaml:"naffka,omitempty"`
     } `yaml:"database"`
     // TURN Server Config
@@ -386,6 +388,8 @@ func (config *Dendrite) check(monolithic bool) error {
         if !monolithic {
             problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server"))
         }
+        checkNotEmpty("database.naffka", string(config.Database.Naffka))
     } else {
         // If we aren't using naffka then we need to have at least one kafka
         // server to talk to.


@@ -0,0 +1,96 @@
// Copyright 2017 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"fmt"
"net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// RoomAliasToID converts the queried alias into a room ID and returns it
func RoomAliasToID(
httpReq *http.Request,
federation *gomatrixserverlib.FederationClient,
cfg config.Dendrite,
aliasAPI api.RoomserverAliasAPI,
) util.JSONResponse {
roomAlias := httpReq.FormValue("alias")
if roomAlias == "" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("Must supply room alias parameter."),
}
}
_, domain, err := gomatrixserverlib.SplitID('#', roomAlias)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("Room alias must be in the form '#localpart:domain'"),
}
}
var resp gomatrixserverlib.RespDirectory
if domain == cfg.Matrix.ServerName {
queryReq := api.GetAliasRoomIDRequest{Alias: roomAlias}
var queryRes api.GetAliasRoomIDResponse
if err = aliasAPI.GetAliasRoomID(httpReq.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(httpReq, err)
}
if queryRes.RoomID != "" {
// TODO: List servers that are aware of this room alias
resp = gomatrixserverlib.RespDirectory{
RoomID: queryRes.RoomID,
Servers: []gomatrixserverlib.ServerName{},
}
} else {
// If the response doesn't contain a non-empty string, return an error
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("Room alias %s not found", roomAlias)),
}
}
} else {
resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias)
if err != nil {
switch x := err.(type) {
case gomatrix.HTTPError:
if x.Code == 404 {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("Room alias not found"),
}
}
}
// TODO: Return 502 if the remote server errored.
// TODO: Return 504 if the remote server timed out.
return httputil.LogThenError(httpReq, err)
}
}
return util.JSONResponse{
Code: 200,
JSON: resp,
}
}
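RoomAliasToID backs the /query/directory federation endpoint registered in the next file. For illustration, the sketch below shows the calling side using gomatrixserverlib's FederationClient.LookupRoomAlias, the same call this handler makes for remote aliases; the destination server name and alias are placeholders.

package main

import (
    "context"
    "fmt"

    "github.com/matrix-org/gomatrixserverlib"
)

// lookupAliasExample resolves a room alias on a remote homeserver. On the
// receiving side this request is served by the RoomAliasToID handler above.
func lookupAliasExample(client *gomatrixserverlib.FederationClient) error {
    resp, err := client.LookupRoomAlias(
        context.Background(),
        gomatrixserverlib.ServerName("remote.example.com"), // destination homeserver
        "#somewhere:remote.example.com",                    // alias to resolve
    )
    if err != nil {
        return err
    }
    fmt.Println("room ID:", resp.RoomID, "servers:", resp.Servers)
    return nil
}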


@@ -38,6 +38,7 @@ func Setup(
     apiMux *mux.Router,
     cfg config.Dendrite,
     query api.RoomserverQueryAPI,
+    aliasAPI api.RoomserverAliasAPI,
     producer *producers.RoomserverProducer,
     keys gomatrixserverlib.KeyRing,
     federation *gomatrixserverlib.FederationClient,
@@ -105,6 +106,15 @@ func Setup(
         },
     )).Methods("GET")
+    v1fedmux.Handle("/query/directory/", common.MakeFedAPI(
+        "federation_query_room_alias", cfg.Matrix.ServerName, keys,
+        func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
+            return RoomAliasToID(
+                httpReq, federation, cfg, aliasAPI,
+            )
+        },
+    )).Methods("GET")
     v1fedmux.Handle("/query/profile", common.MakeFedAPI(
         "federation_query_profile", cfg.Matrix.ServerName, keys,
         func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {

vendor/manifest

@@ -141,7 +141,7 @@
     {
         "importpath": "github.com/matrix-org/naffka",
         "repository": "https://github.com/matrix-org/naffka",
-        "revision": "d28656e34f96a8eeaab53e3b7678c9ce14af5786",
+        "revision": "662bfd0841d0194bfe0a700d54226bb96eac574d",
         "branch": "master"
     },
     {


@@ -8,7 +8,8 @@ import (
 // A MemoryDatabase stores the message history as arrays in memory.
 // It can be used to run unit tests.
 // If the process is stopped then any messages that haven't been
-// processed by a consumer are lost forever.
+// processed by a consumer are lost forever and all offsets become
+// invalid.
 type MemoryDatabase struct {
     topicsMutex sync.Mutex
     topics map[string]*memoryDatabaseTopic
@@ -58,10 +59,7 @@ func (m *MemoryDatabase) getTopic(topicName string) *memoryDatabaseTopic {
 // StoreMessages implements Database
 func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error {
-    if err := m.getTopic(topic).addMessages(messages); err != nil {
-        return err
-    }
-    return nil
+    return m.getTopic(topic).addMessages(messages)
 }
@@ -73,10 +71,10 @@ func (m *MemoryDatabase) FetchMessages(topic string, startOffset, endOffset int6
     if startOffset >= endOffset {
         return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset)
     }
-    if startOffset < -1 {
-        return nil, fmt.Errorf("start offset %d less than -1", startOffset)
+    if startOffset < 0 {
+        return nil, fmt.Errorf("start offset %d less than 0", startOffset)
     }
-    return messages[startOffset+1 : endOffset], nil
+    return messages[startOffset:endOffset], nil
 }
 // MaxOffsets implements Database
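The FetchMessages change above switches the database interface to a half-open range: startOffset is now inclusive, endOffset exclusive, and the first message in a topic sits at offset 0 rather than behind the old -1 sentinel. A minimal sketch against the in-memory implementation (topic name and values are arbitrary):

package main

import (
    "fmt"
    "log"
    "time"

    "github.com/matrix-org/naffka"
)

func main() {
    db := &naffka.MemoryDatabase{}
    msgs := []naffka.Message{
        {Offset: 0, Value: []byte("first"), Timestamp: time.Now()},
        {Offset: 1, Value: []byte("second"), Timestamp: time.Now()},
        {Offset: 2, Value: []byte("third"), Timestamp: time.Now()},
    }
    if err := db.StoreMessages("someTopic", msgs); err != nil {
        log.Fatal(err)
    }
    // The half-open range [0, 2) returns the messages at offsets 0 and 1.
    fetched, err := db.FetchMessages("someTopic", 0, 2)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(len(fetched)) // expected: 2
}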


@@ -13,6 +13,7 @@ import (
 // single go process. It implements both the sarama.SyncProducer and the
 // sarama.Consumer interfaces. This means it can act as a drop in replacement
 // for kafka for testing or single instance deployment.
+// Does not support multiple partitions.
 type Naffka struct {
     db Database
     topicsMutex sync.Mutex
@@ -28,6 +29,7 @@ func New(db Database) (*Naffka, error) {
     }
     for topicName, offset := range maxOffsets {
         n.topics[topicName] = &topic{
+            db: db,
             topicName: topicName,
             nextOffset: offset + 1,
         }
@@ -64,7 +66,7 @@ type Database interface {
     // So for a given topic the message with offset n+1 is stored after the
     // the message with offset n.
     StoreMessages(topic string, messages []Message) error
-    // FetchMessages fetches all messages with an offset greater than but not
+    // FetchMessages fetches all messages with an offset greater than and
     // including startOffset and less than but not including endOffset.
     // The range of offsets requested must not overlap with those stored by a
    // concurrent StoreMessages. The message offsets within the requested range
@@ -138,6 +140,7 @@ func (n *Naffka) Partitions(topic string) ([]int32, error) {
 }
 // ConsumePartition implements sarama.Consumer
+// Note: offset is *inclusive*, i.e. it will include the message with that offset.
 func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
     if partition != 0 {
         return nil, fmt.Errorf("Unknown partition ID %d", partition)
@@ -166,13 +169,16 @@ func (n *Naffka) Close() error {
 const channelSize = 1024
+// partitionConsumer ensures that all messages written to a particular
+// topic, from an offset, get sent in order to a channel.
+// Implements sarama.PartitionConsumer
 type partitionConsumer struct {
     topic *topic
     messages chan *sarama.ConsumerMessage
-    // Whether the consumer is ready for new messages or whether it
-    // is catching up on historic messages.
+    // Whether the consumer is in "catchup" mode or not.
+    // See "catchup" function for details.
     // Reads and writes to this field are proctected by the topic mutex.
-    ready bool
+    catchingUp bool
 }
 // AsyncClose implements sarama.PartitionConsumer
@@ -201,66 +207,101 @@ func (c *partitionConsumer) HighWaterMarkOffset() int64 {
     return c.topic.highwaterMark()
 }
-// block writes the message to the consumer blocking until the consumer is ready
-// to add the message to the channel. Once the message is successfully added to
-// the channel it will catch up by pulling historic messsages from the database.
-func (c *partitionConsumer) block(cmsg *sarama.ConsumerMessage) {
-    c.messages <- cmsg
-    c.catchup(cmsg.Offset)
+// catchup makes the consumer go into "catchup" mode, where messages are read
+// from the database instead of directly from producers.
+// Once the consumer is up to date, i.e. no new messages in the database, then
+// the consumer will go back into normal mode where new messages are written
+// directly to the channel.
+// Must be called with the c.topic.mutex lock
+func (c *partitionConsumer) catchup(fromOffset int64) {
+    // If we're already in catchup mode or up to date, noop
+    if c.catchingUp || fromOffset == c.topic.nextOffset {
+        return
+    }
+    c.catchingUp = true
+    // Due to the checks above there can only be one of these goroutines
+    // running at a time
+    go func() {
+        for {
+            // Check if we're up to date yet. If we are we exit catchup mode.
+            c.topic.mutex.Lock()
+            nextOffset := c.topic.nextOffset
+            if fromOffset == nextOffset {
+                c.catchingUp = false
+                c.topic.mutex.Unlock()
+                return
+            }
+            c.topic.mutex.Unlock()
+            // Limit the number of messages we request from the database to be the
+            // capacity of the channel.
+            if nextOffset > fromOffset+int64(cap(c.messages)) {
+                nextOffset = fromOffset + int64(cap(c.messages))
+            }
+            // Fetch the messages from the database.
+            msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
+            if err != nil {
+                // TODO: Add option to write consumer errors to an errors channel
+                // as an alternative to logging the errors.
+                log.Print("Error reading messages: ", err)
+                // Wait before retrying.
+                // TODO: Maybe use an exponentional backoff scheme here.
+                // TODO: This timeout should take account of all the other goroutines
+                // that might be doing the same thing. (If there are a 10000 consumers
+                // then we don't want to end up retrying every millisecond)
+                time.Sleep(10 * time.Second)
+                continue
+            }
+            if len(msgs) == 0 {
+                // This should only happen if the database is corrupted and has lost the
+                // messages between the requested offsets.
+                log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
+            }
+            // Pass the messages into the consumer channel.
+            // Blocking each write until the channel has enough space for the message.
+            for i := range msgs {
+                c.messages <- msgs[i].consumerMessage(c.topic.topicName)
+            }
+            // Update our the offset for the next loop iteration.
+            fromOffset = msgs[len(msgs)-1].Offset + 1
+        }
+    }()
 }
-// catchup reads historic messages from the database until the consumer has caught
-// up on all the historic messages.
-func (c *partitionConsumer) catchup(fromOffset int64) {
-    for {
-        // First check if we have caught up.
-        caughtUp, nextOffset := c.topic.hasCaughtUp(c, fromOffset)
-        if caughtUp {
-            return
-        }
-        // Limit the number of messages we request from the database to be the
-        // capacity of the channel.
-        if nextOffset > fromOffset+int64(cap(c.messages)) {
-            nextOffset = fromOffset + int64(cap(c.messages))
-        }
-        // Fetch the messages from the database.
-        msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
-        if err != nil {
-            // TODO: Add option to write consumer errors to an errors channel
-            // as an alternative to logging the errors.
-            log.Print("Error reading messages: ", err)
-            // Wait before retrying.
-            // TODO: Maybe use an exponentional backoff scheme here.
-            // TODO: This timeout should take account of all the other goroutines
-            // that might be doing the same thing. (If there are a 10000 consumers
-            // then we don't want to end up retrying every millisecond)
-            time.Sleep(10 * time.Second)
-            continue
-        }
-        if len(msgs) == 0 {
-            // This should only happen if the database is corrupted and has lost the
-            // messages between the requested offsets.
-            log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
-        }
-        // Pass the messages into the consumer channel.
-        // Blocking each write until the channel has enough space for the message.
-        for i := range msgs {
-            c.messages <- msgs[i].consumerMessage(c.topic.topicName)
-        }
-        // Update our the offset for the next loop iteration.
-        fromOffset = msgs[len(msgs)-1].Offset
+// notifyNewMessage tells the consumer about a new message
+// Must be called with the c.topic.mutex lock
+func (c *partitionConsumer) notifyNewMessage(cmsg *sarama.ConsumerMessage) {
+    // If we're in "catchup" mode then the catchup routine will send the
+    // message later, since cmsg has already been written to the database
+    if c.catchingUp {
+        return
+    }
+    // Otherwise, lets try writing the message directly to the channel
+    select {
+    case c.messages <- cmsg:
+    default:
+        // The messages channel has filled up, so lets go into catchup
+        // mode. Once the channel starts being read from again messages
+        // will be read from the database
+        c.catchup(cmsg.Offset)
     }
 }
 type topic struct {
     db Database
     topicName string
     mutex sync.Mutex
     consumers []*partitionConsumer
+    // nextOffset is the offset that will be assigned to the next message in
+    // this topic, i.e. one greater than the last message offset.
     nextOffset int64
 }
+// send writes messages to a topic.
 func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
     var err error
     // Encode the message keys and values.
@@ -298,21 +339,10 @@ func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
     t.nextOffset = offset
     // Now notify the consumers about the messages.
-    for i := range msgs {
-        cmsg := msgs[i].consumerMessage(t.topicName)
+    for _, msg := range msgs {
+        cmsg := msg.consumerMessage(t.topicName)
         for _, c := range t.consumers {
-            if c.ready {
-                select {
-                case c.messages <- cmsg:
-                default:
-                    // The consumer wasn't ready to receive a message because
-                    // the channel buffer was full.
-                    // Fork a goroutine to send the message so that we don't
-                    // block sending messages to the other consumers.
-                    c.ready = false
-                    go c.block(cmsg)
-                }
-            }
+            c.notifyNewMessage(cmsg)
         }
     }
@@ -330,27 +360,17 @@ func (t *topic) consume(offset int64) *partitionConsumer {
         offset = t.nextOffset
     }
     if offset == sarama.OffsetOldest {
-        offset = -1
+        offset = 0
     }
     c.messages = make(chan *sarama.ConsumerMessage, channelSize)
     t.consumers = append(t.consumers, c)
-    // Start catching up on historic messages in the background.
-    go c.catchup(offset)
-    return c
-}
-func (t *topic) hasCaughtUp(c *partitionConsumer, offset int64) (bool, int64) {
-    t.mutex.Lock()
-    defer t.mutex.Unlock()
-    // Check if we have caught up while holding a lock on the topic so there
-    // isn't a way for our check to race with a new message being sent on the topic.
-    if offset+1 == t.nextOffset {
-        // We've caught up, the consumer can now receive messages as they are
-        // sent rather than fetching them from the database.
-        c.ready = true
-        return true, t.nextOffset
+    // If we're not streaming from the latest offset we need to go into
+    // "catchup" mode
+    if offset != t.nextOffset {
+        c.catchup(offset)
     }
-    return false, t.nextOffset
+    return c
 }
 func (t *topic) highwaterMark() int64 {
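Taken together, these changes keep Naffka usable as a drop-in for the sarama producer and consumer interfaces while moving "catchup" reads to the database. A sketch of the intended single-process usage, mirroring the tests below; the sarama import path is assumed from the vendoring conventions of the time:

package main

import (
    "fmt"
    "log"

    sarama "gopkg.in/Shopify/sarama.v1"

    "github.com/matrix-org/naffka"
)

func main() {
    // An in-memory naffka instance acts as both the producer and the consumer.
    naff, err := naffka.New(&naffka.MemoryDatabase{})
    if err != nil {
        log.Fatal(err)
    }
    producer := sarama.SyncProducer(naff)
    consumer := sarama.Consumer(naff)

    if _, _, err = producer.SendMessage(&sarama.ProducerMessage{
        Topic: "exampleTopic",
        Value: sarama.StringEncoder("hello"),
    }); err != nil {
        log.Fatal(err)
    }

    // Offsets now start at 0, so OffsetOldest replays everything from the start.
    c, err := consumer.ConsumePartition("exampleTopic", 0, sarama.OffsetOldest)
    if err != nil {
        log.Fatal(err)
    }
    msg := <-c.Messages()
    fmt.Println(string(msg.Value)) // "hello"
}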


@@ -1,6 +1,7 @@
 package naffka
 import (
+    "strconv"
     "testing"
     "time"
@@ -84,3 +85,142 @@ func TestDelayedReceive(t *testing.T) {
         t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
     }
 }
func TestCatchup(t *testing.T) {
naffka, err := New(&MemoryDatabase{})
if err != nil {
t.Fatal(err)
}
producer := sarama.SyncProducer(naffka)
consumer := sarama.Consumer(naffka)
const topic = "testTopic"
const value = "Hello, World"
message := sarama.ProducerMessage{
Value: sarama.StringEncoder(value),
Topic: topic,
}
if _, _, err = producer.SendMessage(&message); err != nil {
t.Fatal(err)
}
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
if err != nil {
t.Fatal(err)
}
var result *sarama.ConsumerMessage
select {
case result = <-c.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatal("expected to receive a message")
}
if string(result.Value) != value {
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
}
currOffset := result.Offset
const value2 = "Hello, World2"
const value3 = "Hello, World3"
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
Value: sarama.StringEncoder(value2),
Topic: topic,
})
if err != nil {
t.Fatal(err)
}
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
Value: sarama.StringEncoder(value3),
Topic: topic,
})
if err != nil {
t.Fatal(err)
}
t.Logf("Streaming from %q", currOffset+1)
c2, err := consumer.ConsumePartition(topic, 0, currOffset+1)
if err != nil {
t.Fatal(err)
}
var result2 *sarama.ConsumerMessage
select {
case result2 = <-c2.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatal("expected to receive a message")
}
if string(result2.Value) != value2 {
t.Fatalf("wrong value: wanted %q got %q", value2, string(result2.Value))
}
}
func TestChannelSaturation(t *testing.T) {
// The channel returned by c.Messages() has a fixed capacity
naffka, err := New(&MemoryDatabase{})
if err != nil {
t.Fatal(err)
}
producer := sarama.SyncProducer(naffka)
consumer := sarama.Consumer(naffka)
const topic = "testTopic"
const baseValue = "testValue: "
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
if err != nil {
t.Fatal(err)
}
channelSize := cap(c.Messages())
// We want to send enough messages to fill up the channel, so lets double
// the size of the channel. And add three in case its a zero sized channel
numberMessagesToSend := 2*channelSize + 3
var sentMessages []string
for i := 0; i < numberMessagesToSend; i++ {
value := baseValue + strconv.Itoa(i)
message := sarama.ProducerMessage{
Topic: topic,
Value: sarama.StringEncoder(value),
}
sentMessages = append(sentMessages, value)
if _, _, err = producer.SendMessage(&message); err != nil {
t.Fatal(err)
}
}
var result *sarama.ConsumerMessage
j := 0
for ; j < numberMessagesToSend; j++ {
select {
case result = <-c.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatalf("failed to receive message %d out of %d", j+1, numberMessagesToSend)
}
expectedValue := sentMessages[j]
if string(result.Value) != expectedValue {
t.Fatalf("wrong value: wanted %q got %q", expectedValue, string(result.Value))
}
}
select {
case result = <-c.Messages():
t.Fatalf("expected to only receive %d messages", numberMessagesToSend)
default:
}
}


@@ -0,0 +1,296 @@
package naffka
import (
"database/sql"
"sync"
"time"
)
const postgresqlSchema = `
-- The topic table assigns each topic a unique numeric ID.
CREATE SEQUENCE IF NOT EXISTS naffka_topic_nid_seq;
CREATE TABLE IF NOT EXISTS naffka_topics (
topic_name TEXT PRIMARY KEY,
topic_nid BIGINT NOT NULL DEFAULT nextval('naffka_topic_nid_seq')
);
-- The messages table contains the actual messages.
CREATE TABLE IF NOT EXISTS naffka_messages (
topic_nid BIGINT NOT NULL,
message_offset BIGINT NOT NULL,
message_key BYTEA NOT NULL,
message_value BYTEA NOT NULL,
message_timestamp_ns BIGINT NOT NULL,
UNIQUE (topic_nid, message_offset)
);
`
const insertTopicSQL = "" +
"INSERT INTO naffka_topics (topic_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING" +
" RETURNING (topic_nid)"
const selectTopicSQL = "" +
"SELECT topic_nid FROM naffka_topics WHERE topic_name = $1"
const selectTopicsSQL = "" +
"SELECT topic_name, topic_nid FROM naffka_topics"
const insertMessageSQL = "" +
"INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns)" +
" VALUES ($1, $2, $3, $4, $5)"
const selectMessagesSQL = "" +
"SELECT message_offset, message_key, message_value, message_timestamp_ns" +
" FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" +
" ORDER BY message_offset ASC"
const selectMaxOffsetSQL = "" +
"SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" +
" ORDER BY message_offset DESC LIMIT 1"
type postgresqlDatabase struct {
db *sql.DB
topicsMutex sync.Mutex
topicNIDs map[string]int64
insertTopicStmt *sql.Stmt
selectTopicStmt *sql.Stmt
selectTopicsStmt *sql.Stmt
insertMessageStmt *sql.Stmt
selectMessagesStmt *sql.Stmt
selectMaxOffsetStmt *sql.Stmt
}
// NewPostgresqlDatabase creates a new naffka database using a postgresql database.
// Returns an error if there was a problem setting up the database.
func NewPostgresqlDatabase(db *sql.DB) (Database, error) {
var err error
p := &postgresqlDatabase{
db: db,
topicNIDs: map[string]int64{},
}
if _, err = db.Exec(postgresqlSchema); err != nil {
return nil, err
}
for _, s := range []struct {
sql string
stmt **sql.Stmt
}{
{insertTopicSQL, &p.insertTopicStmt},
{selectTopicSQL, &p.selectTopicStmt},
{selectTopicsSQL, &p.selectTopicsStmt},
{insertMessageSQL, &p.insertMessageStmt},
{selectMessagesSQL, &p.selectMessagesStmt},
{selectMaxOffsetSQL, &p.selectMaxOffsetStmt},
} {
*s.stmt, err = db.Prepare(s.sql)
if err != nil {
return nil, err
}
}
return p, nil
}
// StoreMessages implements Database.
func (p *postgresqlDatabase) StoreMessages(topic string, messages []Message) error {
// Store the messages inside a single database transaction.
return withTransaction(p.db, func(txn *sql.Tx) error {
s := txn.Stmt(p.insertMessageStmt)
topicNID, err := p.assignTopicNID(txn, topic)
if err != nil {
return err
}
for _, m := range messages {
_, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano())
if err != nil {
return err
}
}
return nil
})
}
// FetchMessages implements Database.
func (p *postgresqlDatabase) FetchMessages(topic string, startOffset, endOffset int64) (messages []Message, err error) {
topicNID, err := p.getTopicNID(nil, topic)
if err != nil {
return
}
rows, err := p.selectMessagesStmt.Query(topicNID, startOffset, endOffset)
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
var (
offset int64
key []byte
value []byte
timestampNano int64
)
if err = rows.Scan(&offset, &key, &value, &timestampNano); err != nil {
return
}
messages = append(messages, Message{
Offset: offset,
Key: key,
Value: value,
Timestamp: time.Unix(0, timestampNano),
})
}
return
}
// MaxOffsets implements Database.
func (p *postgresqlDatabase) MaxOffsets() (map[string]int64, error) {
topicNames, err := p.selectTopics()
if err != nil {
return nil, err
}
result := map[string]int64{}
for topicName, topicNID := range topicNames {
// Lookup the maximum offset.
maxOffset, err := p.selectMaxOffset(topicNID)
if err != nil {
return nil, err
}
if maxOffset > -1 {
// Don't include the topic if we haven't sent any messages on it.
result[topicName] = maxOffset
}
// Prefill the numeric ID cache.
p.addTopicNIDToCache(topicName, topicNID)
}
return result, nil
}
// selectTopics fetches the names and numeric IDs for all the topics the
// database is aware of.
func (p *postgresqlDatabase) selectTopics() (map[string]int64, error) {
rows, err := p.selectTopicsStmt.Query()
if err != nil {
return nil, err
}
defer rows.Close()
result := map[string]int64{}
for rows.Next() {
var (
topicName string
topicNID int64
)
if err = rows.Scan(&topicName, &topicNID); err != nil {
return nil, err
}
result[topicName] = topicNID
}
return result, nil
}
// selectMaxOffset selects the maximum offset for a topic.
// Returns -1 if there aren't any messages for that topic.
// Returns an error if there was a problem talking to the database.
func (p *postgresqlDatabase) selectMaxOffset(topicNID int64) (maxOffset int64, err error) {
err = p.selectMaxOffsetStmt.QueryRow(topicNID).Scan(&maxOffset)
if err == sql.ErrNoRows {
return -1, nil
}
return maxOffset, err
}
// getTopicNID finds the numeric ID for a topic.
// The txn argument is optional, this can be used outside a transaction
// by setting the txn argument to nil.
func (p *postgresqlDatabase) getTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
// Get from the cache.
topicNID = p.getTopicNIDFromCache(topicName)
if topicNID != 0 {
return topicNID, nil
}
// Get from the database
s := p.selectTopicStmt
if txn != nil {
s = txn.Stmt(s)
}
err = s.QueryRow(topicName).Scan(&topicNID)
if err == sql.ErrNoRows {
return 0, nil
}
if err != nil {
return 0, err
}
// Update the shared cache.
p.addTopicNIDToCache(topicName, topicNID)
return topicNID, nil
}
// assignTopicNID assigns a new numeric ID to a topic.
// The txn argument is mandatory, this is always called inside a transaction.
func (p *postgresqlDatabase) assignTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
// Check if we already have a numeric ID for the topic name.
topicNID, err = p.getTopicNID(txn, topicName)
if err != nil {
return 0, err
}
if topicNID != 0 {
return topicNID, err
}
// We don't have a numeric ID for the topic name so we add an entry to the
// topics table. If the insert stmt succeeds then it will return the ID.
err = txn.Stmt(p.insertTopicStmt).QueryRow(topicName).Scan(&topicNID)
if err == sql.ErrNoRows {
// If the insert stmt succeeded, but didn't return any rows then it
// means that someone has added a row for the topic name between us
// selecting it the first time and us inserting our own row.
// (N.B. postgres only returns modified rows when using "RETURNING")
// So we can now just select the row that someone else added.
// TODO: This is probably unnecessary since naffka writes to a topic
// from a single thread.
return p.getTopicNID(txn, topicName)
}
if err != nil {
return 0, err
}
// Update the cache.
p.addTopicNIDToCache(topicName, topicNID)
return topicNID, nil
}
// getTopicNIDFromCache returns the topicNID from the cache or returns 0 if the
// topic is not in the cache.
func (p *postgresqlDatabase) getTopicNIDFromCache(topicName string) (topicNID int64) {
p.topicsMutex.Lock()
defer p.topicsMutex.Unlock()
return p.topicNIDs[topicName]
}
// addTopicNIDToCache adds the numeric ID for the topic to the cache.
func (p *postgresqlDatabase) addTopicNIDToCache(topicName string, topicNID int64) {
p.topicsMutex.Lock()
defer p.topicsMutex.Unlock()
p.topicNIDs[topicName] = topicNID
}
// withTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolledback
// Otherwise the transaction is committed.
func withTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}
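For completeness, a sketch of wiring the new PostgreSQL-backed store into naffka, mirroring the monolith change earlier in this commit; the connection string and the lib/pq driver import are assumptions made for the example:

package main

import (
    "database/sql"
    "log"

    _ "github.com/lib/pq"

    "github.com/matrix-org/naffka"
)

func main() {
    // Open a postgres connection, wrap it in the naffka database layer,
    // then hand it to naffka.New, as the monolith server now does.
    db, err := sql.Open("postgres", "postgres://dendrite:secret@localhost/naffka?sslmode=disable")
    if err != nil {
        log.Fatal("failed to open naffka database: ", err)
    }
    naffkaDB, err := naffka.NewPostgresqlDatabase(db)
    if err != nil {
        log.Fatal("failed to set up naffka database: ", err)
    }
    naff, err := naffka.New(naffkaDB)
    if err != nil {
        log.Fatal(err)
    }
    _ = naff // used as the sarama producer/consumer, as in the earlier sketch
}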