Set depth of events and whether they need to be federated.

Set the depth of each new event to one greater than the maximum depth
of it's direct ancestors.

Add a flag to each event passing through the roomserver that tells us
whether the event needs to be sent over federation.

We do this by passing the name of the server to send the event as.
We will need this capability if we want to support vhosting as it is
not possible to tell from the event alone which server to send it as.

(The reason for this is that sometimes a event needs to be sent on
behalf of a different remote matrix server)
This commit is contained in:
Mark Haines 2017-06-23 15:32:56 +01:00
parent e67f9401be
commit 3f1a1b806b
13 changed files with 137 additions and 43 deletions

View file

@ -42,7 +42,7 @@ func NewRoomserverProducer(kafkaURIs []string, topic string) (*RoomserverProduce
} }
// SendEvents writes the given events to the roomserver input log. The events are written with KindNew. // SendEvents writes the given events to the roomserver input log. The events are written with KindNew.
func (c *RoomserverProducer) SendEvents(events []gomatrixserverlib.Event) error { func (c *RoomserverProducer) SendEvents(events []gomatrixserverlib.Event, sendAsServer gomatrixserverlib.ServerName) error {
eventIDs := make([]string, len(events)) eventIDs := make([]string, len(events))
ires := make([]api.InputRoomEvent, len(events)) ires := make([]api.InputRoomEvent, len(events))
for i, event := range events { for i, event := range events {
@ -50,6 +50,7 @@ func (c *RoomserverProducer) SendEvents(events []gomatrixserverlib.Event) error
Kind: api.KindNew, Kind: api.KindNew,
Event: event.JSON(), Event: event.JSON(),
AuthEventIDs: authEventIDs(event), AuthEventIDs: authEventIDs(event),
SendAsServer: string(sendAsServer),
} }
eventIDs[i] = event.EventID() eventIDs[i] = event.EventID()
} }

View file

@ -188,7 +188,7 @@ func createRoom(req *http.Request, device *authtypes.Device, cfg config.Dendrite
} }
// send events to the room server // send events to the room server
if err := producer.SendEvents(builtEvents); err != nil { if err := producer.SendEvents(builtEvents, cfg.Matrix.ServerName); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -93,6 +93,7 @@ func SendEvent(
refs = append(refs, e.EventReference()) refs = append(refs, e.EventReference())
} }
builder.AuthEvents = refs builder.AuthEvents = refs
builder.Depth = queryRes.Depth
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName) eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
e, err := builder.Build( e, err := builder.Build(
eventID, time.Now(), cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, eventID, time.Now(), cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
@ -115,7 +116,7 @@ func SendEvent(
} }
// pass the new event to the roomserver // pass the new event to the roomserver
if err := producer.SendEvents([]gomatrixserverlib.Event{e}); err != nil { if err := producer.SendEvents([]gomatrixserverlib.Event{e}, cfg.Matrix.ServerName); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -340,7 +340,8 @@ func main() {
"state_key":"@richvdh:matrix.org", "state_key":"@richvdh:matrix.org",
"type":"m.room.member" "type":"m.room.member"
}, },
"VisibilityEventIDs":null, "StateBeforeRemovesEventIDs":["$1463671339126270PnVwC:matrix.org"],
"StateBeforeAddsEventIDs":null,
"LatestEventIDs":["$1463671339126270PnVwC:matrix.org"], "LatestEventIDs":["$1463671339126270PnVwC:matrix.org"],
"AddsStateEventIDs":["$1463671337126266wrSBX:matrix.org", "$1463671339126270PnVwC:matrix.org"], "AddsStateEventIDs":["$1463671337126266wrSBX:matrix.org", "$1463671339126270PnVwC:matrix.org"],
"RemovesStateEventIDs":null, "RemovesStateEventIDs":null,

View file

@ -159,7 +159,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event) error {
// TODO: Check that the event is allowed by its auth_events. // TODO: Check that the event is allowed by its auth_events.
// pass the event to the roomserver // pass the event to the roomserver
if err := t.producer.SendEvents([]gomatrixserverlib.Event{e}); err != nil { if err := t.producer.SendEvents([]gomatrixserverlib.Event{e}, ""); err != nil {
return err return err
} }

View file

@ -62,6 +62,9 @@ type InputRoomEvent struct {
// These are only used if HasState is true. // These are only used if HasState is true.
// The list can be empty, for example when storing the first event in a room. // The list can be empty, for example when storing the first event in a room.
StateEventIDs []string StateEventIDs []string
// The server name to use to push this event to other servers.
// Or empty if this event shouldn't be pushed to other servers.
SendAsServer string
} }
// UnmarshalJSON implements json.Unmarshaller // UnmarshalJSON implements json.Unmarshaller
@ -76,6 +79,7 @@ func (ire *InputRoomEvent) UnmarshalJSON(data []byte) error {
AuthEventIDs []string AuthEventIDs []string
StateEventIDs []string StateEventIDs []string
HasState bool HasState bool
SendAsServer string
} }
if err := json.Unmarshal(data, &content); err != nil { if err := json.Unmarshal(data, &content); err != nil {
return err return err
@ -84,6 +88,7 @@ func (ire *InputRoomEvent) UnmarshalJSON(data []byte) error {
ire.AuthEventIDs = content.AuthEventIDs ire.AuthEventIDs = content.AuthEventIDs
ire.StateEventIDs = content.StateEventIDs ire.StateEventIDs = content.StateEventIDs
ire.HasState = content.HasState ire.HasState = content.HasState
ire.SendAsServer = content.SendAsServer
if content.Event != nil { if content.Event != nil {
ire.Event = []byte(*content.Event) ire.Event = []byte(*content.Event)
} }
@ -103,12 +108,14 @@ func (ire InputRoomEvent) MarshalJSON() ([]byte, error) {
AuthEventIDs []string AuthEventIDs []string
StateEventIDs []string StateEventIDs []string
HasState bool HasState bool
SendAsServer string
}{ }{
Kind: ire.Kind, Kind: ire.Kind,
AuthEventIDs: ire.AuthEventIDs, AuthEventIDs: ire.AuthEventIDs,
StateEventIDs: ire.StateEventIDs, StateEventIDs: ire.StateEventIDs,
Event: &event, Event: &event,
HasState: ire.HasState, HasState: ire.HasState,
SendAsServer: ire.SendAsServer,
} }
return json.Marshal(&content) return json.Marshal(&content)
} }

View file

@ -22,9 +22,6 @@ import (
type OutputRoomEvent struct { type OutputRoomEvent struct {
// The JSON bytes of the event. // The JSON bytes of the event.
Event []byte Event []byte
// The state event IDs needed to determine who can see this event.
// This can be used to tell which users to send the event to.
VisibilityEventIDs []string
// The latest events in the room after this event. // The latest events in the room after this event.
// This can be used to set the prev events for new events in the room. // This can be used to set the prev events for new events in the room.
// This also can be used to get the full current state after this event. // This also can be used to get the full current state after this event.
@ -43,6 +40,17 @@ type OutputRoomEvent struct {
// If the LastSentEventID doesn't match what they were expecting it to be // If the LastSentEventID doesn't match what they were expecting it to be
// they can use the LatestEventIDs to request the full current state. // they can use the LatestEventIDs to request the full current state.
LastSentEventID string LastSentEventID string
// The state event IDs that are part of the state at the event, but not
// part of the current state. Together with the StateBeforeRemovesEventIDs
// this can be used to construct the state before the event from the
// current state.
StateBeforeAddsEventIDs []string
// The state event IDs that are part of the current state, but not part
// of the state at the event.
StateBeforeRemovesEventIDs []string
// The server name to use to push this event to other servers.
// Or empty if this event shouldn't be pushed to other servers.
SendAsServer string
} }
// UnmarshalJSON implements json.Unmarshaller // UnmarshalJSON implements json.Unmarshaller
@ -52,12 +60,14 @@ func (ore *OutputRoomEvent) UnmarshalJSON(data []byte) error {
// We use json.RawMessage so that the event JSON is sent as JSON rather than // We use json.RawMessage so that the event JSON is sent as JSON rather than
// being base64 encoded which is the default for []byte. // being base64 encoded which is the default for []byte.
var content struct { var content struct {
Event *json.RawMessage Event *json.RawMessage
VisibilityEventIDs []string LatestEventIDs []string
LatestEventIDs []string AddsStateEventIDs []string
AddsStateEventIDs []string RemovesStateEventIDs []string
RemovesStateEventIDs []string LastSentEventID string
LastSentEventID string StateBeforeAddsEventIDs []string
StateBeforeRemovesEventIDs []string
SendAsServer string
} }
if err := json.Unmarshal(data, &content); err != nil { if err := json.Unmarshal(data, &content); err != nil {
return err return err
@ -65,11 +75,13 @@ func (ore *OutputRoomEvent) UnmarshalJSON(data []byte) error {
if content.Event != nil { if content.Event != nil {
ore.Event = []byte(*content.Event) ore.Event = []byte(*content.Event)
} }
ore.VisibilityEventIDs = content.VisibilityEventIDs
ore.LatestEventIDs = content.LatestEventIDs ore.LatestEventIDs = content.LatestEventIDs
ore.AddsStateEventIDs = content.AddsStateEventIDs ore.AddsStateEventIDs = content.AddsStateEventIDs
ore.RemovesStateEventIDs = content.RemovesStateEventIDs ore.RemovesStateEventIDs = content.RemovesStateEventIDs
ore.LastSentEventID = content.LastSentEventID ore.LastSentEventID = content.LastSentEventID
ore.StateBeforeAddsEventIDs = content.StateBeforeAddsEventIDs
ore.StateBeforeRemovesEventIDs = content.StateBeforeRemovesEventIDs
ore.SendAsServer = content.SendAsServer
return nil return nil
} }
@ -81,19 +93,23 @@ func (ore OutputRoomEvent) MarshalJSON() ([]byte, error) {
// being base64 encoded which is the default for []byte. // being base64 encoded which is the default for []byte.
event := json.RawMessage(ore.Event) event := json.RawMessage(ore.Event)
content := struct { content := struct {
Event *json.RawMessage Event *json.RawMessage
VisibilityEventIDs []string LatestEventIDs []string
LatestEventIDs []string AddsStateEventIDs []string
AddsStateEventIDs []string RemovesStateEventIDs []string
RemovesStateEventIDs []string LastSentEventID string
LastSentEventID string StateBeforeAddsEventIDs []string
StateBeforeRemovesEventIDs []string
SendAsServer string
}{ }{
Event: &event, Event: &event,
VisibilityEventIDs: ore.VisibilityEventIDs, LatestEventIDs: ore.LatestEventIDs,
LatestEventIDs: ore.LatestEventIDs, AddsStateEventIDs: ore.AddsStateEventIDs,
AddsStateEventIDs: ore.AddsStateEventIDs, RemovesStateEventIDs: ore.RemovesStateEventIDs,
RemovesStateEventIDs: ore.RemovesStateEventIDs, LastSentEventID: ore.LastSentEventID,
LastSentEventID: ore.LastSentEventID, StateBeforeAddsEventIDs: ore.StateBeforeAddsEventIDs,
StateBeforeRemovesEventIDs: ore.StateBeforeRemovesEventIDs,
SendAsServer: ore.SendAsServer,
} }
return json.Marshal(&content) return json.Marshal(&content)
} }

View file

@ -43,6 +43,9 @@ type QueryLatestEventsAndStateResponse struct {
// The state events requested. // The state events requested.
// This list will be in an arbitrary order. // This list will be in an arbitrary order.
StateEvents []gomatrixserverlib.Event StateEvents []gomatrixserverlib.Event
// The depth of the latest events.
// This is one greater than the depths of the latest events.
Depth int64
} }
// QueryStateAfterEventsRequest is a request to QueryStateAfterEvents // QueryStateAfterEventsRequest is a request to QueryStateAfterEvents

View file

@ -102,7 +102,7 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
} }
// Update the extremities of the event graph for the room // Update the extremities of the event graph for the room
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event); err != nil { if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
return err return err
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
// updateLatestEvents updates the list of latest events for this room in the database and writes the // updateLatestEvents updates the list of latest events for this room in the database and writes the
@ -39,7 +40,12 @@ import (
// 7 <----- latest // 7 <----- latest
// //
func updateLatestEvents( func updateLatestEvents(
db RoomEventDatabase, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, db RoomEventDatabase,
ow OutputRoomEventWriter,
roomNID types.RoomNID,
stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event,
sendAsServer string,
) (err error) { ) (err error) {
updater, err := db.GetLatestEventsForUpdate(roomNID) updater, err := db.GetLatestEventsForUpdate(roomNID)
if err != nil { if err != nil {
@ -59,12 +65,18 @@ func updateLatestEvents(
} }
}() }()
err = doUpdateLatestEvents(db, updater, ow, roomNID, stateAtEvent, event) err = doUpdateLatestEvents(db, updater, ow, roomNID, stateAtEvent, event, sendAsServer)
return return
} }
func doUpdateLatestEvents( func doUpdateLatestEvents(
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, db RoomEventDatabase,
updater types.RoomRecentEventsUpdater,
ow OutputRoomEventWriter,
roomNID types.RoomNID,
stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event,
sendAsServer string,
) error { ) error {
var err error var err error
var prevEvents []gomatrixserverlib.EventReference var prevEvents []gomatrixserverlib.EventReference
@ -110,6 +122,13 @@ func doUpdateLatestEvents(
return err return err
} }
stateBeforeEventRemoves, stateBeforeEventAdds, err := state.DifferenceBetweeenStateSnapshots(
db, newStateNID, stateAtEvent.BeforeStateSnapshotNID,
)
if err != nil {
return err
}
// Send the event to the output logs. // Send the event to the output logs.
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
@ -118,7 +137,10 @@ func doUpdateLatestEvents(
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in // send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now. // necessary bookkeeping we'll keep the event sending synchronous for now.
if err = writeEvent(db, ow, lastEventIDSent, event, newLatest, removed, added); err != nil { if err = writeEvent(
db, ow, lastEventIDSent, event, newLatest, removed, added,
stateBeforeEventRemoves, stateBeforeEventAdds, sendAsServer,
); err != nil {
return err return err
} }
@ -170,6 +192,8 @@ func writeEvent(
db RoomEventDatabase, ow OutputRoomEventWriter, lastEventIDSent string, db RoomEventDatabase, ow OutputRoomEventWriter, lastEventIDSent string,
event gomatrixserverlib.Event, latest []types.StateAtEventAndReference, event gomatrixserverlib.Event, latest []types.StateAtEventAndReference,
removed, added []types.StateEntry, removed, added []types.StateEntry,
stateBeforeEventRemoves, stateBeforeEventAdds []types.StateEntry,
sendAsServer string,
) error { ) error {
latestEventIDs := make([]string, len(latest)) latestEventIDs := make([]string, len(latest))
@ -190,6 +214,13 @@ func writeEvent(
for _, entry := range removed { for _, entry := range removed {
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
for _, entry := range stateBeforeEventRemoves {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
for _, entry := range stateBeforeEventAdds {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
eventIDMap, err := db.EventIDs(stateEventNIDs) eventIDMap, err := db.EventIDs(stateEventNIDs)
if err != nil { if err != nil {
return err return err
@ -200,7 +231,19 @@ func writeEvent(
for _, entry := range removed { for _, entry := range removed {
ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID]) ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
} }
for _, entry := range stateBeforeEventRemoves {
ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range stateBeforeEventAdds {
ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
}
ore.SendAsServer = sendAsServer
// TODO: Fill out VisibilityStateIDs // TODO: Fill out VisibilityStateIDs
return ow.WriteOutputRoomEvent(ore) return ow.WriteOutputRoomEvent(ore)
} }
type eventNIDSorter []types.EventNID
func (s eventNIDSorter) Len() int { return len(s) }
func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -34,7 +34,7 @@ type RoomserverQueryAPIDatabase interface {
RoomNID(roomID string) (types.RoomNID, error) RoomNID(roomID string) (types.RoomNID, error)
// Lookup event references for the latest events in the room and the current state snapshot. // Lookup event references for the latest events in the room and the current state snapshot.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
// Lookup the numeric IDs for a list of events. // Lookup the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
EventNIDs(eventIDs []string) (map[string]types.EventNID, error) EventNIDs(eventIDs []string) (map[string]types.EventNID, error)
@ -60,7 +60,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
} }
response.RoomExists = true response.RoomExists = true
var currentStateSnapshotNID types.StateSnapshotNID var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, err = r.DB.LatestEventIDs(roomNID) response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -46,7 +46,9 @@ CREATE TABLE IF NOT EXISTS events (
-- part of the event graph -- part of the event graph
-- Since many different events can have the same state we store the -- Since many different events can have the same state we store the
-- state into a separate state table and refer to it by numeric ID. -- state into a separate state table and refer to it by numeric ID.
state_snapshot_nid bigint NOT NULL DEFAULT 0, state_snapshot_nid BIGINT NOT NULL DEFAULT 0,
-- Depth of the event in the event graph.
depth BIGINT NOT NULL,
-- The textual event id. -- The textual event id.
-- Used to lookup the numeric ID when processing requests. -- Used to lookup the numeric ID when processing requests.
-- Needed for state resolution. -- Needed for state resolution.
@ -61,8 +63,8 @@ CREATE TABLE IF NOT EXISTS events (
` `
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" + "INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" +
" VALUES ($1, $2, $3, $4, $5, $6)" + " VALUES ($1, $2, $3, $4, $5, $6, $7)" +
" ON CONFLICT ON CONSTRAINT event_id_unique" + " ON CONFLICT ON CONSTRAINT event_id_unique" +
" DO NOTHING" + " DO NOTHING" +
" RETURNING event_nid, state_snapshot_nid" " RETURNING event_nid, state_snapshot_nid"
@ -107,6 +109,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" + const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM events WHERE event_id = ANY($1)" "SELECT event_id, event_nid FROM events WHERE event_id = ANY($1)"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM events WHERE event_nid = ANY($1)"
type eventStatements struct { type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
@ -120,6 +125,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
} }
func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -141,6 +147,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
}.prepare(db) }.prepare(db)
} }
@ -149,12 +156,13 @@ func (s *eventStatements) insertEvent(
eventID string, eventID string,
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
depth int64,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.insertEventStmt.QueryRow( err := s.insertEventStmt.QueryRow(
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
eventNIDsAsArray(authEventNIDs), eventNIDsAsArray(authEventNIDs), depth,
).Scan(&eventNID, &stateNID) ).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
@ -357,6 +365,15 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type
return results, nil return results, nil
} }
func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) {
var result int64
err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result)
if err != nil {
return 0, err
}
return result, nil
}
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
nids := make([]int64, len(eventNIDs)) nids := make([]int64, len(eventNIDs))
for i := range eventNIDs { for i := range eventNIDs {

View file

@ -87,6 +87,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
event.EventID(), event.EventID(),
event.EventReference().EventSHA256, event.EventReference().EventSHA256,
authEventNIDs, authEventNIDs,
event.Depth(),
); err != nil { ); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID // We've already inserted the event so select the numeric event ID
@ -349,16 +350,20 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
} }
// LatestEventIDs implements query.RoomserverQueryAPIDB // LatestEventIDs implements query.RoomserverQueryAPIDB
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) { func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID) eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, 0, err
} }
references, err := d.statements.bulkSelectEventReference(eventNIDs) references, err := d.statements.bulkSelectEventReference(eventNIDs)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, 0, err
} }
return references, currentStateSnapshotNID, nil depth, err := d.statements.selectMaxEventDepth(eventNIDs)
if err != nil {
return nil, 0, 0, err
}
return references, currentStateSnapshotNID, depth, nil
} }
// StateEntriesForTuples implements state.RoomStateDatabase // StateEntriesForTuples implements state.RoomStateDatabase