Merge branch 'master' into profile-fed-query

This commit is contained in:
Alex Chen 2019-07-12 23:47:17 +08:00 committed by GitHub
commit 8e0b0bc7b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 807 additions and 284 deletions

View file

@ -117,12 +117,16 @@ func SetLocalAlias(
// 1. The new method for checking for things matching an AS's namespace // 1. The new method for checking for things matching an AS's namespace
// 2. Using an overall Regex object for all AS's just like we did for usernames // 2. Using an overall Regex object for all AS's just like we did for usernames
for _, appservice := range cfg.Derived.ApplicationServices { for _, appservice := range cfg.Derived.ApplicationServices {
if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok { // Don't prevent AS from creating aliases in its own namespace
for _, namespace := range aliasNamespaces { // Note that Dendrite uses SenderLocalpart as UserID for AS users
if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { if device.UserID != appservice.SenderLocalpart {
return util.JSONResponse{ if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok {
Code: http.StatusBadRequest, for _, namespace := range aliasNamespaces {
JSON: jsonerror.ASExclusive("Alias is reserved by an application service"), if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.ASExclusive("Alias is reserved by an application service"),
}
} }
} }
} }

View file

@ -62,7 +62,7 @@ func GetFilter(
filter := gomatrix.Filter{} filter := gomatrix.Filter{}
err = json.Unmarshal(res, &filter) err = json.Unmarshal(res, &filter)
if err != nil { if err != nil {
httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
return util.JSONResponse{ return util.JSONResponse{

View file

@ -107,7 +107,7 @@ func Login(
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()
if err != nil { if err != nil {
httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token)

View file

@ -58,27 +58,12 @@ func SendMembership(
} }
} }
inviteStored, err := threepid.CheckAndProcessInvite( inviteStored, jsonErrResp := checkAndProcessThreepid(
req.Context(), device, &body, cfg, queryAPI, accountDB, producer, req, device, &body, cfg, queryAPI, accountDB, producer,
membership, roomID, evTime, membership, roomID, evTime,
) )
if err == threepid.ErrMissingParameter { if jsonErrResp != nil {
return util.JSONResponse{ return *jsonErrResp
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} else if err == threepid.ErrNotTrusted {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotTrusted(body.IDServer),
}
} else if err == common.ErrRoomNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()),
}
} else if err != nil {
return httputil.LogThenError(req, err)
} }
// If an invite has been stored on an identity server, it means that a // If an invite has been stored on an identity server, it means that a
@ -114,9 +99,18 @@ func SendMembership(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
var returnData interface{} = struct{}{}
// The join membership requires the room id to be sent in the response
if membership == "join" {
returnData = struct {
RoomID string `json:"room_id"`
}{roomID}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: returnData,
} }
} }
@ -215,3 +209,41 @@ func getMembershipStateKey(
return return
} }
func checkAndProcessThreepid(
req *http.Request,
device *authtypes.Device,
body *threepid.MembershipRequest,
cfg config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI,
accountDB *accounts.Database,
producer *producers.RoomserverProducer,
membership, roomID string,
evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) {
inviteStored, err := threepid.CheckAndProcessInvite(
req.Context(), device, body, cfg, queryAPI, accountDB, producer,
membership, roomID, evTime,
)
if err == threepid.ErrMissingParameter {
return inviteStored, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
} else if err == threepid.ErrNotTrusted {
return inviteStored, &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.NotTrusted(body.IDServer),
}
} else if err == common.ErrRoomNoExists {
return inviteStored, &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(err.Error()),
}
} else if err != nil {
er := httputil.LogThenError(req, err)
return inviteStored, &er
}
return
}

View file

@ -243,8 +243,8 @@ func validateRecaptcha(
) *util.JSONResponse { ) *util.JSONResponse {
if !cfg.Matrix.RecaptchaEnabled { if !cfg.Matrix.RecaptchaEnabled {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusConflict,
JSON: jsonerror.BadJSON("Captcha registration is disabled"), JSON: jsonerror.Unknown("Captcha registration is disabled"),
} }
} }
@ -279,8 +279,8 @@ func validateRecaptcha(
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusGatewayTimeout,
JSON: jsonerror.BadJSON("Error in contacting captcha server" + err.Error()), JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
} }
} }
err = json.Unmarshal(body, &r) err = json.Unmarshal(body, &r)

View file

@ -305,6 +305,10 @@ func (r *downloadRequest) respondFromLocalFile(
}).Info("Responding with file") }).Info("Responding with file")
responseFile = file responseFile = file
responseMetadata = r.MediaMetadata responseMetadata = r.MediaMetadata
if len(responseMetadata.UploadName) > 0 {
w.Header().Set("Content-Disposition", fmt.Sprintf(`inline; filename*=utf-8"%s"`, responseMetadata.UploadName))
}
} }
w.Header().Set("Content-Type", string(responseMetadata.ContentType)) w.Header().Set("Content-Type", string(responseMetadata.ContentType))

View file

@ -277,6 +277,20 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
servMux.Handle(
roomserverAPI.RoomserverGetAliasesForRoomIDPath,
common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {
var request roomserverAPI.GetAliasesForRoomIDRequest
var response roomserverAPI.GetAliasesForRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle( servMux.Handle(
roomserverAPI.RoomserverRemoveRoomAliasPath, roomserverAPI.RoomserverRemoveRoomAliasPath,
common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {

View file

@ -22,7 +22,15 @@ then args="--fast"
fi fi
echo "Installing golangci-lint..." echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first
# TODO: Once go 1.13 is out, use go get's -mod=readonly option
# https://github.com/golang/go/issues/30667
cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint go get github.com/golangci/golangci-lint/cmd/golangci-lint
echo "Looking for lint..." echo "Looking for lint..."
golangci-lint run $args golangci-lint run $args
# Restore go.{mod,sum}
mv go.mod.bak go.mod && mv go.sum.bak go.sum

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1" sarama "gopkg.in/Shopify/sarama.v1"
) )
@ -29,7 +30,7 @@ import (
// OutputClientDataConsumer consumes events that originated in the client API server. // OutputClientDataConsumer consumes events that originated in the client API server.
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
clientAPIConsumer *common.ContinualConsumer clientAPIConsumer *common.ContinualConsumer
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
notifier *sync.Notifier notifier *sync.Notifier
} }
@ -38,7 +39,7 @@ func NewOutputClientDataConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier, n *sync.Notifier,
store *storage.SyncServerDatabase, store *storage.SyncServerDatasource,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
consumer := common.ContinualConsumer{ consumer := common.ContinualConsumer{
@ -78,7 +79,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
"room_id": output.RoomID, "room_id": output.RoomID,
}).Info("received data from client API server") }).Info("received data from client API server")
syncStreamPos, err := s.db.UpsertAccountData( pduPos, err := s.db.UpsertAccountData(
context.TODO(), string(msg.Key), output.RoomID, output.Type, context.TODO(), string(msg.Key), output.RoomID, output.Type,
) )
if err != nil { if err != nil {
@ -89,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, string(msg.Key), syncStreamPos) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -33,7 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer roomServerConsumer *common.ContinualConsumer
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
notifier *sync.Notifier notifier *sync.Notifier
query api.RoomserverQueryAPI query api.RoomserverQueryAPI
} }
@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer(
cfg *config.Dendrite, cfg *config.Dendrite,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *sync.Notifier, n *sync.Notifier,
store *storage.SyncServerDatabase, store *storage.SyncServerDatasource,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
@ -126,7 +126,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
} }
syncStreamPos, err := s.db.WriteEvent( pduPos, err := s.db.WriteEvent(
ctx, ctx,
&ev, &ev,
addsStateEvents, addsStateEvents,
@ -144,7 +144,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", types.StreamPosition(syncStreamPos)) s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }
@ -152,7 +152,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
syncStreamPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -161,7 +161,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", syncStreamPos) s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos})
return nil return nil
} }

View file

@ -0,0 +1,96 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"encoding/json"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/api"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
// OutputTypingEventConsumer consumes events that originated in the typing server.
type OutputTypingEventConsumer struct {
typingConsumer *common.ContinualConsumer
db *storage.SyncServerDatasource
notifier *sync.Notifier
}
// NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer.
// Call Start() to begin consuming from the typing server.
func NewOutputTypingEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store *storage.SyncServerDatasource,
) *OutputTypingEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputTypingEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputTypingEventConsumer{
typingConsumer: &consumer,
db: store,
notifier: n,
}
consumer.ProcessMessage = s.onMessage
return s
}
// Start consuming from typing api
func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition})
})
return s.typingConsumer.Start()
}
func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var output api.OutputTypingEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("typing server output log: message parse failure")
return nil
}
log.WithFields(log.Fields{
"room_id": output.Event.RoomID,
"user_id": output.Event.UserID,
"typing": output.Event.Typing,
}).Debug("received data from typing server")
var typingPos int64
typingEvent := output.Event
if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime)
} else {
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
}
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos})
return nil
}

View file

@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is // Due to Setup being used to call many other functions, a gocyclo nolint is
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) { func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatasource, deviceDB *devices.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{ authData := auth.Data{

View file

@ -40,7 +40,7 @@ type stateEventInStateResp struct {
// TODO: Check if the user is in the room. If not, check if the room's history // TODO: Check if the user is in the room. If not, check if the room's history
// is publicly visible. Current behaviour is returning an empty array if the // is publicly visible. Current behaviour is returning an empty array if the
// user cannot see the room's history. // user cannot see the room's history.
func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string) util.JSONResponse { func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)
@ -84,7 +84,7 @@ func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, r
// /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current
// state to see if there is an event with that type and state key, if there // state to see if there is an event with that type and state key, if there
// is then (by default) we return the content, otherwise a 404. // is then (by default) we return the content, otherwise a 404.
func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string, evType, stateKey string) util.JSONResponse { func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string, evType, stateKey string) util.JSONResponse {
// TODO(#287): Auth request and handle the case where the user has left (where // TODO(#287): Auth request and handle the case where the user has left (where
// we should return the state at the poin they left) // we should return the state at the poin they left)

View file

@ -19,8 +19,6 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -94,7 +92,7 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) selectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, oldPos, newPos int64,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)

View file

@ -23,7 +23,6 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -109,11 +108,11 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
return return
} }
// selectStateInRange returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) selectStateInRange( func (s *outputRoomEventsStatements) selectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, ctx context.Context, txn *sql.Tx, oldPos, newPos int64,
) (map[string]map[string]bool, map[string]streamEvent, error) { ) (map[string]map[string]bool, map[string]streamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
@ -171,7 +170,7 @@ func (s *outputRoomEventsStatements) selectStateInRange(
eventIDToEvent[ev.EventID()] = streamEvent{ eventIDToEvent[ev.EventID()] = streamEvent{
Event: ev, Event: ev,
streamPosition: types.StreamPosition(streamPos), streamPosition: streamPos,
} }
} }
@ -223,7 +222,7 @@ func (s *outputRoomEventsStatements) insertEvent(
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) selectRecentEvents( func (s *outputRoomEventsStatements) selectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, fromPos, toPos int64, limit int,
) ([]streamEvent, error) { ) ([]streamEvent, error) {
stmt := common.TxStmt(txn, s.selectRecentEventsStmt) stmt := common.TxStmt(txn, s.selectRecentEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
@ -286,7 +285,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
result = append(result, streamEvent{ result = append(result, streamEvent{
Event: ev, Event: ev,
streamPosition: types.StreamPosition(streamPos), streamPosition: streamPos,
transactionID: transactionID, transactionID: transactionID,
}) })
} }

View file

@ -17,7 +17,10 @@ package storage
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"strconv"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -28,6 +31,7 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -35,33 +39,35 @@ type stateDelta struct {
roomID string roomID string
stateEvents []gomatrixserverlib.Event stateEvents []gomatrixserverlib.Event
membership string membership string
// The stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.
membershipPos types.StreamPosition membershipPos int64
} }
// Same as gomatrixserverlib.Event but also has the stream position for this event. // Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type streamEvent struct { type streamEvent struct {
gomatrixserverlib.Event gomatrixserverlib.Event
streamPosition types.StreamPosition streamPosition int64
transactionID *api.TransactionID transactionID *api.TransactionID
} }
// SyncServerDatabase represents a sync server database // SyncServerDatabase represents a sync server datasource which manages
type SyncServerDatabase struct { // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
accountData accountDataStatements accountData accountDataStatements
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
invites inviteEventsStatements invites inviteEventsStatements
typingCache *cache.TypingCache
} }
// NewSyncServerDatabase creates a new sync server database // NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) {
var d SyncServerDatabase var d SyncServerDatasource
var err error var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil { if d.db, err = sql.Open("postgres", dbDataSourceName); err != nil {
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
@ -79,11 +85,12 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
if err := d.invites.prepare(d.db); err != nil { if err := d.invites.prepare(d.db); err != nil {
return nil, err return nil, err
} }
d.typingCache = cache.NewTypingCache()
return &d, nil return &d, nil
} }
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.roomstate.selectJoinedUsers(ctx) return d.roomstate.selectJoinedUsers(ctx)
} }
@ -92,7 +99,7 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -104,38 +111,38 @@ func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]g
} }
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
// when generating the stream position for this event. Returns the sync stream position for the inserted event. // when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
// Returns an error if there was a problem inserting this event. // Returns an error if there was a problem inserting this event.
func (d *SyncServerDatabase) WriteEvent( func (d *SyncServerDatasource) WriteEvent(
ctx context.Context, ctx context.Context,
ev *gomatrixserverlib.Event, ev *gomatrixserverlib.Event,
addStateEvents []gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (streamPos types.StreamPosition, returnErr error) { ) (pduPosition int64, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID)
if err != nil { if err != nil {
return err return err
} }
streamPos = types.StreamPosition(pos) pduPosition = pos
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
return nil return nil
} }
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos) return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
}) })
return return
} }
func (d *SyncServerDatabase) updateRoomState( func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
addedEvents []gomatrixserverlib.Event, addedEvents []gomatrixserverlib.Event,
streamPos types.StreamPosition, pduPosition int64,
) error { ) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
@ -157,7 +164,7 @@ func (d *SyncServerDatabase) updateRoomState(
} }
membership = &value membership = &value
} }
if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil { if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return err
} }
} }
@ -168,7 +175,7 @@ func (d *SyncServerDatabase) updateRoomState(
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
func (d *SyncServerDatabase) GetStateEvent( func (d *SyncServerDatasource) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
@ -177,7 +184,7 @@ func (d *SyncServerDatabase) GetStateEvent(
// GetStateEventsForRoom fetches the state events for a given room. // GetStateEventsForRoom fetches the state events for a given room.
// Returns an empty slice if no state events could be found for this room. // Returns an empty slice if no state events could be found for this room.
// Returns an error if there was an issue with the retrieval. // Returns an error if there was an issue with the retrieval.
func (d *SyncServerDatabase) GetStateEventsForRoom( func (d *SyncServerDatasource) GetStateEventsForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (stateEvents []gomatrixserverlib.Event, err error) { ) (stateEvents []gomatrixserverlib.Event, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
@ -187,46 +194,49 @@ func (d *SyncServerDatabase) GetStateEventsForRoom(
return return
} }
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) {
return d.syncStreamPositionTx(ctx, nil) return d.syncPositionTx(ctx, nil)
} }
func (d *SyncServerDatabase) syncStreamPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (types.StreamPosition, error) { ) (sp types.SyncPosition, err error) {
maxID, err := d.events.selectMaxEventID(ctx, txn)
maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
if maxAccountDataID > maxID { if maxAccountDataID > maxEventID {
maxID = maxAccountDataID maxEventID = maxAccountDataID
} }
maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return 0, err return sp, err
} }
if maxInviteID > maxID { if maxInviteID > maxEventID {
maxID = maxInviteID maxEventID = maxInviteID
} }
return types.StreamPosition(maxID), nil sp.PDUPosition = maxEventID
sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
return
} }
// IncrementalSync returns all the data needed in order to create an incremental // addPDUDeltaToResponse adds all PDU deltas to a sync response.
// sync response for the given user. Events returned will include any client // IDs of all rooms the user joined are returned so EDU deltas can be added for them.
// transaction IDs associated with the given device. These transaction IDs come func (d *SyncServerDatasource) addPDUDeltaToResponse(
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatabase) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) (*types.Response, error) { res *types.Response,
) ([]string, error) {
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil { if err != nil {
return nil, err return nil, err
@ -235,7 +245,7 @@ func (d *SyncServerDatabase) IncrementalSync(
defer common.EndTransaction(txn, &succeeded) defer common.EndTransaction(txn, &succeeded)
// Work out which rooms to return in the response. This is done by getting not only the currently // Work out which rooms to return in the response. This is done by getting not only the currently
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
// to put the room into. // to put the room into.
deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID)
@ -243,8 +253,9 @@ func (d *SyncServerDatabase) IncrementalSync(
return nil, err return nil, err
} }
res := types.NewResponse(toPos) joinedRoomIDs := make([]string, 0, len(deltas))
for _, delta := range deltas { for _, delta := range deltas {
joinedRoomIDs = append(joinedRoomIDs, delta.roomID)
err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, err
@ -257,52 +268,151 @@ func (d *SyncServerDatabase) IncrementalSync(
} }
succeeded = true succeeded = true
return joinedRoomIDs, nil
}
// addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse(
since int64,
joinedRoomIDs []string,
res *types.Response,
) error {
var jr types.JoinResponse
var ok bool
var err error
for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
roomID, since,
); updated {
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping,
}
ev.Content, err = json.Marshal(map[string]interface{}{
"user_ids": typingUsers,
})
if err != nil {
return err
}
if jr, ok = res.Rooms.Join[roomID]; !ok {
jr = *types.NewJoinResponse()
}
jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
res.Rooms.Join[roomID] = jr
}
}
return nil
}
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.SyncPosition,
joinedRoomIDs []string,
res *types.Response,
) (err error) {
if fromPos.TypingPosition != toPos.TypingPosition {
err = d.addTypingDeltaToResponse(
fromPos.TypingPosition, joinedRoomIDs, res,
)
}
return
}
// IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context,
device authtypes.Device,
fromPos, toPos types.SyncPosition,
numRecentEventsPerRoom int,
) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos)
var joinedRoomIDs []string
var err error
if fromPos.PDUPosition != toPos.PDUPosition {
joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, res,
)
} else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
ctx, nil, device.UserID, "join",
)
}
if err != nil {
return nil, err
}
err = d.addEDUDeltaToResponse(
fromPos, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
}
return res, nil return res, nil
} }
// CompleteSync a complete /sync API response for the given user. // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
func (d *SyncServerDatabase) CompleteSync( // to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
ctx context.Context, userID string, numRecentEventsPerRoom int, func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
) (*types.Response, error) { ctx context.Context,
userID string,
numRecentEventsPerRoom int,
) (
res *types.Response,
toPos types.SyncPosition,
joinedRoomIDs []string,
err error,
) {
// This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
// a consistent view of the database throughout. This includes extracting the sync stream position. // a consistent view of the database throughout. This includes extracting the sync position.
// This does have the unfortunate side-effect that all the matrixy logic resides in this function, // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
// but it's better to not hide the fact that this is being done in a transaction. // but it's better to not hide the fact that this is being done in a transaction.
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil { if err != nil {
return nil, err return
} }
var succeeded bool var succeeded bool
defer common.EndTransaction(txn, &succeeded) defer common.EndTransaction(txn, &succeeded)
// Get the current stream position which we will base the sync response on. // Get the current sync position which we will base the sync response on.
pos, err := d.syncStreamPositionTx(ctx, txn) toPos, err = d.syncPositionTx(ctx, txn)
if err != nil { if err != nil {
return nil, err return
} }
res = types.NewResponse(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join")
if err != nil { if err != nil {
return nil, err return
} }
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
res := types.NewResponse(pos) for _, roomID := range joinedRoomIDs {
for _, roomID := range roomIDs {
var stateEvents []gomatrixserverlib.Event var stateEvents []gomatrixserverlib.Event
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID)
if err != nil { if err != nil {
return nil, err return
} }
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []streamEvent var recentStreamEvents []streamEvent
recentStreamEvents, err = d.events.selectRecentEvents( recentStreamEvents, err = d.events.selectRecentEvents(
ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return nil, err return
} }
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
@ -311,10 +421,12 @@ func (d *SyncServerDatabase) CompleteSync(
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else { } else {
jr.Timeline.PrevBatch = types.StreamPosition(1).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1"
} }
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
@ -322,12 +434,34 @@ func (d *SyncServerDatabase) CompleteSync(
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, pos, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
return nil, err return
} }
succeeded = true succeeded = true
return res, err return res, toPos, joinedRoomIDs, err
}
// CompleteSync returns a complete /sync API response for the given user.
func (d *SyncServerDatasource) CompleteSync(
ctx context.Context, userID string, numRecentEventsPerRoom int,
) (*types.Response, error) {
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, userID, numRecentEventsPerRoom,
)
if err != nil {
return nil, err
}
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
types.SyncPosition{}, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
}
return res, nil
} }
var txReadOnlySnapshot = sql.TxOptions{ var txReadOnlySnapshot = sql.TxOptions{
@ -345,8 +479,8 @@ var txReadOnlySnapshot = sql.TxOptions{
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *SyncServerDatabase) GetAccountDataInRange( func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition, ctx context.Context, userID string, oldPos, newPos int64,
) (map[string][]string, error) { ) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos) return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos)
} }
@ -357,26 +491,24 @@ func (d *SyncServerDatabase) GetAccountDataInRange(
// If no data with the given type, user ID and room ID exists in the database, // If no data with the given type, user ID and room ID exists in the database,
// creates a new row, else update the existing one // creates a new row, else update the existing one
// Returns an error if there was an issue with the upsert // Returns an error if there was an issue with the upsert
func (d *SyncServerDatabase) UpsertAccountData( func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (types.StreamPosition, error) { ) (int64, error) {
pos, err := d.accountData.insertAccountData(ctx, userID, roomID, dataType) return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
return types.StreamPosition(pos), err
} }
// AddInviteEvent stores a new invite event for a user. // AddInviteEvent stores a new invite event for a user.
// If the invite was successfully stored this returns the stream ID it was stored at. // If the invite was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) AddInviteEvent( func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (types.StreamPosition, error) { ) (int64, error) {
pos, err := d.invites.insertInviteEvent(ctx, inviteEvent) return d.invites.insertInviteEvent(ctx, inviteEvent)
return types.StreamPosition(pos), err
} }
// RetireInviteEvent removes an old invite event from the database. // RetireInviteEvent removes an old invite event from the database.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatabase) RetireInviteEvent( func (d *SyncServerDatasource) RetireInviteEvent(
ctx context.Context, inviteEventID string, ctx context.Context, inviteEventID string,
) error { ) error {
// TODO: Record that invite has been retired in a stream so that we can // TODO: Record that invite has been retired in a stream so that we can
@ -385,10 +517,30 @@ func (d *SyncServerDatabase) RetireInviteEvent(
return err return err
} }
func (d *SyncServerDatabase) addInvitesToResponse( func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.typingCache.SetTimeoutCallback(fn)
}
// AddTypingUser adds a typing user to the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time,
) int64 {
return d.typingCache.AddTypingUser(userID, roomID, expireTime)
}
// RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string,
) int64 {
return d.typingCache.RemoveUser(userID, roomID)
}
func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userID string,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.invites.selectInviteEventsInRange( invites, err := d.invites.selectInviteEventsInRange(
@ -409,11 +561,11 @@ func (d *SyncServerDatabase) addInvitesToResponse(
} }
// addRoomDeltaToResponse adds a room state delta to a sync response // addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatabase) addRoomDeltaToResponse( func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *authtypes.Device, device *authtypes.Device,
txn *sql.Tx, txn *sql.Tx,
fromPos, toPos types.StreamPosition, fromPos, toPos int64,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
res *types.Response, res *types.Response,
@ -445,10 +597,12 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
switch delta.membership { switch delta.membership {
case "join": case "join":
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else { } else {
jr.Timeline.PrevBatch = types.StreamPosition(1).String() // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1"
} }
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
@ -460,10 +614,12 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
lr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() // Use the short form of batch token for prev_batch
lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else { } else {
lr.Timeline.PrevBatch = types.StreamPosition(1).String() // Use the short form of batch token for prev_batch
lr.Timeline.PrevBatch = "1"
} }
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
@ -476,7 +632,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events. // Returns a map of room ID to list of events.
func (d *SyncServerDatabase) fetchStateEvents( func (d *SyncServerDatasource) fetchStateEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool, roomIDToEventIDSet map[string]map[string]bool,
eventIDToEvent map[string]streamEvent, eventIDToEvent map[string]streamEvent,
@ -521,7 +677,7 @@ func (d *SyncServerDatabase) fetchStateEvents(
return stateBetween, nil return stateBetween, nil
} }
func (d *SyncServerDatabase) fetchMissingStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]streamEvent, error) { ) ([]streamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
@ -560,9 +716,9 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
return events, nil return events, nil
} }
func (d *SyncServerDatabase) getStateDeltas( func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, fromPos, toPos int64, userID string,
) ([]stateDelta, error) { ) ([]stateDelta, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
@ -601,7 +757,7 @@ func (d *SyncServerDatabase) getStateDeltas(
} }
s := make([]streamEvent, len(allState)) s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: types.StreamPosition(0)} s[i] = streamEvent{Event: allState[i], streamPosition: 0}
} }
state[roomID] = s state[roomID] = s
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms

View file

@ -26,7 +26,7 @@ import (
) )
// Notifier will wake up sleeping requests when there is some new data. // Notifier will wake up sleeping requests when there is some new data.
// It does not tell requests what that data is, only the stream position which // It does not tell requests what that data is, only the sync position which
// they can use to get at it. This is done to prevent races whereby we tell the caller // they can use to get at it. This is done to prevent races whereby we tell the caller
// the event, but the token has already advanced by the time they fetch it, resulting // the event, but the token has already advanced by the time they fetch it, resulting
// in missed events. // in missed events.
@ -35,18 +35,18 @@ type Notifier struct {
roomIDToJoinedUsers map[string]userIDSet roomIDToJoinedUsers map[string]userIDSet
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync stream position // The latest sync position
currPos types.StreamPosition currPos types.SyncPosition
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
lastCleanUpTime time.Time lastCleanUpTime time.Time
} }
// NewNotifier creates a new notifier set to the given stream position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.StreamPosition) *Notifier { func NewNotifier(pos types.SyncPosition) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
@ -58,20 +58,30 @@ func NewNotifier(pos types.StreamPosition) *Notifier {
// OnNewEvent is called when a new event is received from the room server. Must only be // OnNewEvent is called when a new event is received from the room server. Must only be
// called from a single goroutine, to avoid races between updates which could set the // called from a single goroutine, to avoid races between updates which could set the
// current position in the stream incorrectly. // current sync position incorrectly.
// Can be called either with a *gomatrixserverlib.Event, or with an user ID // Chooses which user sync streams to update by a provided *gomatrixserverlib.Event
func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos types.StreamPosition) { // (based on the users in the event's room),
// a roomID directly, or a list of user IDs, prioritised by parameter ordering.
// posUpdate contains the latest position(s) for one or more types of events.
// If a position in posUpdate is 0, it means no updates are available of that type.
// Typically a consumer supplies a posUpdate with the latest sync position for the
// event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.Event, roomID string, userIDs []string,
posUpdate types.SyncPosition,
) {
// update the current position then notify relevant /sync streams. // update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos = pos latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
if ev != nil { if ev != nil {
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
userIDs := n.joinedUsers(ev.RoomID()) usersToNotify := n.joinedUsers(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID := *ev.StateKey() targetUserID := *ev.StateKey()
@ -84,11 +94,11 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty
// Keep the joined user map up-to-date // Keep the joined user map up-to-date
switch membership { switch membership {
case "invite": case "invite":
userIDs = append(userIDs, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
case "join": case "join":
// Manually append the new user's ID so they get notified // Manually append the new user's ID so they get notified
// along all members in the room // along all members in the room
userIDs = append(userIDs, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
n.addJoinedUser(ev.RoomID(), targetUserID) n.addJoinedUser(ev.RoomID(), targetUserID)
case "leave": case "leave":
fallthrough fallthrough
@ -98,11 +108,15 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty
} }
} }
for _, toNotifyUserID := range userIDs { n.wakeupUsers(usersToNotify, latestPos)
n.wakeupUser(toNotifyUserID, pos) } else if roomID != "" {
} n.wakeupUsers(n.joinedUsers(roomID), latestPos)
} else if len(userID) > 0 { } else if len(userIDs) > 0 {
n.wakeupUser(userID, pos) n.wakeupUsers(userIDs, latestPos)
} else {
log.WithFields(log.Fields{
"posUpdate": posUpdate.String,
}).Warn("Notifier.OnNewEvent called but caller supplied no user to wake up")
} }
} }
@ -127,7 +141,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.
func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error { func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatasource) error {
roomToUsers, err := db.AllJoinedUsersInRooms(ctx) roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
if err != nil { if err != nil {
return err return err
@ -136,8 +150,11 @@ func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) err
return nil return nil
} }
// CurrentPosition returns the current stream position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.StreamPosition { func (n *Notifier) CurrentPosition() types.SyncPosition {
n.streamLock.Lock()
defer n.streamLock.Unlock()
return n.currPos return n.currPos
} }
@ -156,12 +173,13 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) {
stream := n.fetchUserStream(userID, false) for _, userID := range userIDs {
if stream == nil { stream := n.fetchUserStream(userID, false)
return if stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}
} }
stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream
} }
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,

View file

@ -32,19 +32,40 @@ var (
randomMessageEvent gomatrixserverlib.Event randomMessageEvent gomatrixserverlib.Event
aliceInviteBobEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event
bobLeaveEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event
syncPositionVeryOld types.SyncPosition
syncPositionBefore types.SyncPosition
syncPositionAfter types.SyncPosition
syncPositionNewEDU types.SyncPosition
syncPositionAfter2 types.SyncPosition
) )
var ( var (
streamPositionVeryOld = types.StreamPosition(5) roomID = "!test:localhost"
streamPositionBefore = types.StreamPosition(11) alice = "@alice:localhost"
streamPositionAfter = types.StreamPosition(12) bob = "@bob:localhost"
streamPositionAfter2 = types.StreamPosition(13)
roomID = "!test:localhost"
alice = "@alice:localhost"
bob = "@bob:localhost"
) )
func init() { func init() {
baseSyncPos := types.SyncPosition{
PDUPosition: 0,
TypingPosition: 0,
}
syncPositionVeryOld = baseSyncPos
syncPositionVeryOld.PDUPosition = 5
syncPositionBefore = baseSyncPos
syncPositionBefore.PDUPosition = 11
syncPositionAfter = baseSyncPos
syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter
syncPositionNewEDU.TypingPosition = 1
syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13
var err error var err error
randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{
"type": "m.room.message", "type": "m.room.message",
@ -92,19 +113,19 @@ func init() {
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld)) pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld))
if err != nil { if err != nil {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
if pos != streamPositionBefore { if pos != syncPositionBefore {
t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos) t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos)
} }
} }
// Test that new events to a joined room unblocks the request. // Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) { func TestNewEventAndJoinedToRoom(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -112,12 +133,12 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -125,14 +146,14 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
stream := n.fetchUserStream(bob, true) stream := n.fetchUserStream(bob, true)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
wg.Wait() wg.Wait()
} }
// Test that an invite unblocks the request // Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) { func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -140,12 +161,12 @@ func TestNewInviteEventForUser(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -153,14 +174,42 @@ func TestNewInviteEventForUser(t *testing.T) {
stream := n.fetchUserStream(bob, true) stream := n.fetchUserStream(bob, true)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", streamPositionAfter) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
wg.Wait()
}
// Test an EDU-only update wakes up the request.
func TestEDUWakeup(t *testing.T) {
n := NewNotifier(syncPositionAfter)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
if err != nil {
t.Errorf("TestNewInviteEventForUser error: %s", err)
}
if pos != syncPositionNewEDU {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos)
}
wg.Done()
}()
stream := n.fetchUserStream(bob, true)
waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
wg.Wait() wg.Wait()
} }
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -168,12 +217,12 @@ func TestMultipleRequestWakeup(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(3) wg.Add(3)
poll := func() { poll := func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestMultipleRequestWakeup error: %s", err) t.Errorf("TestMultipleRequestWakeup error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos) t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
} }
@ -184,7 +233,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
stream := n.fetchUserStream(bob, true) stream := n.fetchUserStream(bob, true)
waitForBlocking(stream, 3) waitForBlocking(stream, 3)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
wg.Wait() wg.Wait()
@ -198,7 +247,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room. // listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well. // Make sure alice gets woken up only and not bob as well.
n := NewNotifier(streamPositionBefore) n := NewNotifier(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -208,18 +257,18 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// Make bob leave the room // Make bob leave the room
leaveWG.Add(1) leaveWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos)
} }
leaveWG.Done() leaveWG.Done()
}() }()
bobStream := n.fetchUserStream(bob, true) bobStream := n.fetchUserStream(bob, true)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&bobLeaveEvent, "", streamPositionAfter) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
leaveWG.Wait() leaveWG.Wait()
// send an event into the room. Make sure alice gets it. Bob should not. // send an event into the room. Make sure alice gets it. Bob should not.
@ -227,19 +276,19 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
aliceStream := n.fetchUserStream(alice, true) aliceStream := n.fetchUserStream(alice, true)
aliceWG.Add(1) aliceWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != streamPositionAfter2 { if pos != syncPositionAfter2 {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos)
} }
aliceWG.Done() aliceWG.Done()
}() }()
go func() { go func() {
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
_, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter)) _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter))
if err == nil { if err == nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
} }
@ -248,7 +297,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
waitForBlocking(aliceStream, 1) waitForBlocking(aliceStream, 1)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter2) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter2)
aliceWG.Wait() aliceWG.Wait()
// it's possible that at this point alice has been informed and bob is about to be informed, so wait // it's possible that at this point alice has been informed and bob is about to be informed, so wait
@ -256,18 +305,17 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
// same as Notifier.WaitForEvents but with a timeout. func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) {
func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.StreamPosition(0), fmt.Errorf( return types.SyncPosition{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetStreamPosition() p := listener.GetSyncPosition()
return p, nil return p, nil
} }
} }
@ -280,7 +328,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
} }
} }
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,

View file

@ -16,8 +16,10 @@ package sync
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -36,7 +38,7 @@ type syncRequest struct {
device authtypes.Device device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamPosition // nil means that no since token was supplied since *types.SyncPosition // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -73,15 +75,41 @@ func getTimeout(timeoutMS string) time.Duration {
} }
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a // getSyncStreamPosition tries to parse a 'since' token taken from the API to a
// stream position. If the string is empty then (nil, nil) is returned. // types.SyncPosition. If the string is empty then (nil, nil) is returned.
func getSyncStreamPosition(since string) (*types.StreamPosition, error) { // There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens.
func getSyncStreamPosition(since string) (*types.SyncPosition, error) {
if since == "" { if since == "" {
return nil, nil return nil, nil
} }
i, err := strconv.Atoi(since)
if err != nil { posStrings := strings.Split(since, "_")
return nil, err if len(posStrings) != 2 && len(posStrings) != 1 {
// A token can either be full length or short (PDU-only).
return nil, errors.New("malformed batch token")
}
positions := make([]int64, len(posStrings))
for i, posString := range posStrings {
pos, err := strconv.ParseInt(posString, 10, 64)
if err != nil {
return nil, err
}
positions[i] = pos
}
if len(positions) == 2 {
// Full length token; construct SyncPosition with every entry in
// `positions`. These entries must have the same order with the fields
// in struct SyncPosition, so we disable the govet check below.
return &types.SyncPosition{ //nolint:govet
positions[0], positions[1],
}, nil
} else {
// Token with PDU position only
return &types.SyncPosition{
PDUPosition: positions[0],
}, nil
} }
token := types.StreamPosition(i)
return &token, nil
} }

View file

@ -31,13 +31,13 @@ import (
// RequestPool manages HTTP long-poll connections for /sync // RequestPool manages HTTP long-poll connections for /sync
type RequestPool struct { type RequestPool struct {
db *storage.SyncServerDatabase db *storage.SyncServerDatasource
accountDB *accounts.Database accountDB *accounts.Database
notifier *Notifier notifier *Notifier
} }
// NewRequestPool makes a new RequestPool // NewRequestPool makes a new RequestPool
func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool { func NewRequestPool(db *storage.SyncServerDatasource, n *Notifier, adb *accounts.Database) *RequestPool {
return &RequestPool{db, adb, n} return &RequestPool{db, adb, n}
} }
@ -92,11 +92,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// respond with, so we skip the return an go back to waiting for content to // respond with, so we skip the return an go back to waiting for content to
// be sent down or the request timing out. // be sent down or the request timing out.
var hasTimedOut bool var hasTimedOut bool
sincePos := *syncReq.since
for { for {
select { select {
// Wait for notifier to wake us up // Wait for notifier to wake us up
case <-userStreamListener.GetNotifyChannel(currPos): case <-userStreamListener.GetNotifyChannel(sincePos):
currPos = userStreamListener.GetStreamPosition() currPos = userStreamListener.GetSyncPosition()
sincePos = currPos
// Or for timeout to expire // Or for timeout to expire
case <-timer.C: case <-timer.C:
// We just need to ensure we get out of the select after reaching the // We just need to ensure we get out of the select after reaching the
@ -128,24 +130,24 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit)
} }
if err != nil { if err != nil {
return return
} }
res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition)
return return
} }
func (rp *RequestPool) appendAccountData( func (rp *RequestPool) appendAccountData(
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, data *types.Response, userID string, req syncRequest, currentPos int64,
) (*types.Response, error) { ) (*types.Response, error) {
// TODO: Account data doesn't have a sync position of its own, meaning that // TODO: Account data doesn't have a sync position of its own, meaning that
// account data might be sent multiple time to the client if multiple account // account data might be sent multiple time to the client if multiple account
@ -179,7 +181,7 @@ func (rp *RequestPool) appendAccountData(
} }
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, *req.since, currentPos) dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -34,8 +34,8 @@ type UserStream struct {
lock sync.Mutex lock sync.Mutex
// Closed when there is an update. // Closed when there is an update.
signalChannel chan struct{} signalChannel chan struct{}
// The last stream position that there may have been an update for the suser // The last sync position that there may have been an update for the user
pos types.StreamPosition pos types.SyncPosition
// The last time when we had some listeners waiting // The last time when we had some listeners waiting
timeOfLastChannel time.Time timeOfLastChannel time.Time
// The number of listeners waiting // The number of listeners waiting
@ -51,7 +51,7 @@ type UserStreamListener struct {
} }
// NewUserStream creates a new user stream // NewUserStream creates a new user stream
func NewUserStream(userID string, currPos types.StreamPosition) *UserStream { func NewUserStream(userID string, currPos types.SyncPosition) *UserStream {
return &UserStream{ return &UserStream{
UserID: userID, UserID: userID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
@ -84,8 +84,8 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
return listener return listener
} }
// Broadcast a new stream position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.StreamPosition) { func (s *UserStream) Broadcast(pos types.SyncPosition) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -118,9 +118,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
return s.timeOfLastChannel return s.timeOfLastChannel
} }
// GetStreamPosition returns last stream position which the UserStream was // GetStreamPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -132,11 +132,11 @@ func (s *UserStreamListener) GetStreamPosition() types.StreamPosition {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamPosition) <-chan struct{} { func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
if sincePos < s.userStream.pos { if s.userStream.pos.IsAfter(sincePos) {
// If the listener is behind, i.e. missed a potential update, then we // If the listener is behind, i.e. missed a potential update, then we
// want them to wake up immediately. We do this by returning a new // want them to wake up immediately. We do this by returning a new
// closed stream, which returns immediately when selected. // closed stream, which returns immediately when selected.

View file

@ -28,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
) )
// SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI // SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI
@ -39,17 +38,17 @@ func SetupSyncAPIComponent(
accountsDB *accounts.Database, accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
) { ) {
syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI)) syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI))
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to sync db") logrus.WithError(err).Panicf("failed to connect to sync db")
} }
pos, err := syncDB.SyncStreamPosition(context.Background()) pos, err := syncDB.SyncPosition(context.Background())
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to get stream position") logrus.WithError(err).Panicf("failed to get sync position")
} }
notifier := sync.NewNotifier(types.StreamPosition(pos)) notifier := sync.NewNotifier(pos)
err = notifier.Load(context.Background(), syncDB) err = notifier.Load(context.Background(), syncDB)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to start notifier") logrus.WithError(err).Panicf("failed to start notifier")
@ -71,5 +70,12 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
} }
typingConsumer := consumers.NewOutputTypingEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = typingConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start typing server consumer")
}
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) routing.Setup(base.APIMux, requestPool, syncDB, deviceDB)
} }

View file

@ -21,12 +21,38 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// StreamPosition represents the offset in the sync stream a client is at. // SyncPosition contains the PDU and EDU stream sync positions for a client.
type StreamPosition int64 type SyncPosition struct {
// PDUPosition is the stream position for PDUs the client is at.
PDUPosition int64
// TypingPosition is the client's position for typing notifications.
TypingPosition int64
}
// String implements the Stringer interface. // String implements the Stringer interface.
func (sp StreamPosition) String() string { func (sp SyncPosition) String() string {
return strconv.FormatInt(int64(sp), 10) return strconv.FormatInt(sp.PDUPosition, 10) + "_" +
strconv.FormatInt(sp.TypingPosition, 10)
}
// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition.
func (sp SyncPosition) IsAfter(other SyncPosition) bool {
return sp.PDUPosition > other.PDUPosition ||
sp.TypingPosition > other.TypingPosition
}
// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition.
// If the latter SyncPosition contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called.
func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition {
ret := sp
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
}
if other.TypingPosition != 0 {
ret.TypingPosition = other.TypingPosition
}
return ret
} }
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -53,11 +79,10 @@ type Response struct {
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(pos StreamPosition) *Response { func NewResponse(pos SyncPosition) *Response {
res := Response{} res := Response{
// Make sure we send the next_batch as a string. We don't want to confuse clients by sending this NextBatch: pos.String(),
// as an integer even though (at the moment) it is. }
res.NextBatch = pos.String()
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse) res.Rooms.Join = make(map[string]JoinResponse)

View file

@ -42,6 +42,7 @@ POST /join/:room_alias can join a room
POST /join/:room_id can join a room POST /join/:room_id can join a room
POST /join/:room_id can join a room with custom content POST /join/:room_id can join a room with custom content
POST /join/:room_alias can join a room with custom content POST /join/:room_alias can join a room with custom content
POST /rooms/:room_id/join can join a room
POST /rooms/:room_id/leave can leave a room POST /rooms/:room_id/leave can leave a room
POST /rooms/:room_id/invite can send an invite POST /rooms/:room_id/invite can send an invite
POST /rooms/:room_id/ban can ban a user POST /rooms/:room_id/ban can ban a user
@ -143,4 +144,7 @@ Events come down the correct room
local user can join room with version 5 local user can join room with version 5
User can invite local user to room with version 5 User can invite local user to room with version 5
Inbound federation can receive room-join requests Inbound federation can receive room-join requests
Typing events appear in initial sync
Typing events appear in incremental sync
Typing events appear in gapped sync
Outbound federation can query profile data Outbound federation can query profile data

View file

@ -12,14 +12,17 @@
package api package api
import "time"
// OutputTypingEvent is an entry in typing server output kafka log. // OutputTypingEvent is an entry in typing server output kafka log.
// This contains the event with extra fields used to create 'm.typing' event // This contains the event with extra fields used to create 'm.typing' event
// in clientapi & federation. // in clientapi & federation.
type OutputTypingEvent struct { type OutputTypingEvent struct {
// The Event for the typing edu event. // The Event for the typing edu event.
Event TypingEvent `json:"event"` Event TypingEvent `json:"event"`
// Users typing in the room when the event was generated. // ExpireTime is the interval after which the user should no longer be
TypingUsers []string `json:"typing_users"` // considered typing. Only available if Event.Typing is true.
ExpireTime *time.Time
} }
// TypingEvent represents a matrix edu event of type 'm.typing'. // TypingEvent represents a matrix edu event of type 'm.typing'.

View file

@ -22,25 +22,66 @@ const defaultTypingTimeout = 10 * time.Second
// userSet is a map of user IDs to a timer, timer fires at expiry. // userSet is a map of user IDs to a timer, timer fires at expiry.
type userSet map[string]*time.Timer type userSet map[string]*time.Timer
// TimeoutCallbackFn is a function called right after the removal of a user
// from the typing user list due to timeout.
// latestSyncPosition is the typing sync position after the removal.
type TimeoutCallbackFn func(userID, roomID string, latestSyncPosition int64)
type roomData struct {
syncPosition int64
userSet userSet
}
// TypingCache maintains a list of users typing in each room. // TypingCache maintains a list of users typing in each room.
type TypingCache struct { type TypingCache struct {
sync.RWMutex sync.RWMutex
data map[string]userSet latestSyncPosition int64
data map[string]*roomData
timeoutCallback TimeoutCallbackFn
}
// Create a roomData with its sync position set to the latest sync position.
// Must only be called after locking the cache.
func (t *TypingCache) newRoomData() *roomData {
return &roomData{
syncPosition: t.latestSyncPosition,
userSet: make(userSet),
}
} }
// NewTypingCache returns a new TypingCache initialised for use. // NewTypingCache returns a new TypingCache initialised for use.
func NewTypingCache() *TypingCache { func NewTypingCache() *TypingCache {
return &TypingCache{data: make(map[string]userSet)} return &TypingCache{data: make(map[string]*roomData)}
}
// SetTimeoutCallback sets a callback function that is called right after
// a user is removed from the typing user list due to timeout.
func (t *TypingCache) SetTimeoutCallback(fn TimeoutCallbackFn) {
t.timeoutCallback = fn
} }
// GetTypingUsers returns the list of users typing in a room. // GetTypingUsers returns the list of users typing in a room.
func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { func (t *TypingCache) GetTypingUsers(roomID string) []string {
users, _ := t.GetTypingUsersIfUpdatedAfter(roomID, 0)
// 0 should work above because the first position used will be 1.
return users
}
// GetTypingUsersIfUpdatedAfter returns all users typing in this room with
// updated == true if the typing sync position of the room is after the given
// position. Otherwise, returns an empty slice with updated == false.
func (t *TypingCache) GetTypingUsersIfUpdatedAfter(
roomID string, position int64,
) (users []string, updated bool) {
t.RLock() t.RLock()
usersMap, ok := t.data[roomID] defer t.RUnlock()
t.RUnlock()
if ok { roomData, ok := t.data[roomID]
users = make([]string, 0, len(usersMap)) if ok && roomData.syncPosition > position {
for userID := range usersMap { updated = true
userSet := roomData.userSet
users = make([]string, 0, len(userSet))
for userID := range userSet {
users = append(users, userID) users = append(users, userID)
} }
} }
@ -51,25 +92,41 @@ func (t *TypingCache) GetTypingUsers(roomID string) (users []string) {
// AddTypingUser sets an user as typing in a room. // AddTypingUser sets an user as typing in a room.
// expire is the time when the user typing should time out. // expire is the time when the user typing should time out.
// if expire is nil, defaultTypingTimeout is assumed. // if expire is nil, defaultTypingTimeout is assumed.
func (t *TypingCache) AddTypingUser(userID, roomID string, expire *time.Time) { // Returns the latest sync position for typing after update.
func (t *TypingCache) AddTypingUser(
userID, roomID string, expire *time.Time,
) int64 {
expireTime := getExpireTime(expire) expireTime := getExpireTime(expire)
if until := time.Until(expireTime); until > 0 { if until := time.Until(expireTime); until > 0 {
timer := time.AfterFunc(until, t.timeoutCallback(userID, roomID)) timer := time.AfterFunc(until, func() {
t.addUser(userID, roomID, timer) latestSyncPosition := t.RemoveUser(userID, roomID)
if t.timeoutCallback != nil {
t.timeoutCallback(userID, roomID, latestSyncPosition)
}
})
return t.addUser(userID, roomID, timer)
} }
return t.GetLatestSyncPosition()
} }
// addUser with mutex lock & replace the previous timer. // addUser with mutex lock & replace the previous timer.
func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) { // Returns the latest typing sync position after update.
func (t *TypingCache) addUser(
userID, roomID string, expiryTimer *time.Timer,
) int64 {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
t.latestSyncPosition++
if t.data[roomID] == nil { if t.data[roomID] == nil {
t.data[roomID] = make(userSet) t.data[roomID] = t.newRoomData()
} else {
t.data[roomID].syncPosition = t.latestSyncPosition
} }
// Stop the timer to cancel the call to timeoutCallback // Stop the timer to cancel the call to timeoutCallback
if timer, ok := t.data[roomID][userID]; ok { if timer, ok := t.data[roomID].userSet[userID]; ok {
// It may happen that at this stage timer fires but now we have a lock on t. // It may happen that at this stage timer fires but now we have a lock on t.
// Hence the execution of timeoutCallback will happen after we unlock. // Hence the execution of timeoutCallback will happen after we unlock.
// So we may lose a typing state, though this event is highly unlikely. // So we may lose a typing state, though this event is highly unlikely.
@ -78,26 +135,40 @@ func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) {
timer.Stop() timer.Stop()
} }
t.data[roomID][userID] = expiryTimer t.data[roomID].userSet[userID] = expiryTimer
}
// Returns a function which is called after timeout happens. return t.latestSyncPosition
// This removes the user.
func (t *TypingCache) timeoutCallback(userID, roomID string) func() {
return func() {
t.RemoveUser(userID, roomID)
}
} }
// RemoveUser with mutex lock & stop the timer. // RemoveUser with mutex lock & stop the timer.
func (t *TypingCache) RemoveUser(userID, roomID string) { // Returns the latest sync position for typing after update.
func (t *TypingCache) RemoveUser(userID, roomID string) int64 {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
if timer, ok := t.data[roomID][userID]; ok { roomData, ok := t.data[roomID]
timer.Stop() if !ok {
delete(t.data[roomID], userID) return t.latestSyncPosition
} }
timer, ok := roomData.userSet[userID]
if !ok {
return t.latestSyncPosition
}
timer.Stop()
delete(roomData.userSet, userID)
t.latestSyncPosition++
t.data[roomID].syncPosition = t.latestSyncPosition
return t.latestSyncPosition
}
func (t *TypingCache) GetLatestSyncPosition() int64 {
t.Lock()
defer t.Unlock()
return t.latestSyncPosition
} }
func getExpireTime(expire *time.Time) time.Time { func getExpireTime(expire *time.Time) time.Time {

View file

@ -57,15 +57,21 @@ func (t *TypingServerInputAPI) InputTypingEvent(
} }
func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error {
userIDs := t.Cache.GetTypingUsers(ite.RoomID)
ev := &api.TypingEvent{ ev := &api.TypingEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
RoomID: ite.RoomID, RoomID: ite.RoomID,
UserID: ite.UserID, UserID: ite.UserID,
Typing: ite.Typing,
} }
ote := &api.OutputTypingEvent{ ote := &api.OutputTypingEvent{
Event: *ev, Event: *ev,
TypingUsers: userIDs, }
if ev.Typing {
expireTime := ite.OriginServerTS.Time().Add(
time.Duration(ite.Timeout) * time.Millisecond,
)
ote.ExpireTime = &expireTime
} }
eventJSON, err := json.Marshal(ote) eventJSON, err := json.Marshal(ote)