Honour history_visibility when backfilling (#990)
* Make backfill work for shared history visibility * fetch missing state on backfill to remember snapshots correctly * Fix gmsl to not mux in auth events into room state * Whoops * Linting
This commit is contained in:
parent
458b364781
commit
4ad52c67ca
|
@ -16,7 +16,9 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -29,13 +31,20 @@ func GetEvent(
|
|||
request *gomatrixserverlib.FederationRequest,
|
||||
query api.RoomserverQueryAPI,
|
||||
eventID string,
|
||||
origin gomatrixserverlib.ServerName,
|
||||
) util.JSONResponse {
|
||||
event, err := getEvent(ctx, request, query, eventID)
|
||||
if err != nil {
|
||||
return *err
|
||||
}
|
||||
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: event}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{
|
||||
Origin: origin,
|
||||
OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
PDUs: []json.RawMessage{
|
||||
event.JSON(),
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
// getEvent returns the requested event,
|
||||
|
|
|
@ -126,7 +126,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetEvent(
|
||||
httpReq.Context(), request, query, vars["eventID"],
|
||||
httpReq.Context(), request, query, vars["eventID"], cfg.Matrix.ServerName,
|
||||
)
|
||||
},
|
||||
)).Methods(http.MethodGet)
|
||||
|
|
2
go.mod
2
go.mod
|
@ -17,7 +17,7 @@ require (
|
|||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200318135427-31631a9ef51f
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200325174927-327088cdef10
|
||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429143250-5df6426424bd
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429162354-392f0b1b7421
|
||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
|
||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
||||
|
|
4
go.sum
4
go.sum
|
@ -367,8 +367,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
|||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429143250-5df6426424bd h1:YA+1/Y/NK6dHAxRamybQiE6HToTC+5ddPCO4UI7pmH0=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429143250-5df6426424bd/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429162354-392f0b1b7421 h1:4zP29YlpfEtJ9a7sZ33Mf0FJInD2N3/KzDcLa62bRKc=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200429162354-392f0b1b7421/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk=
|
||||
github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
|
||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
||||
|
|
|
@ -27,7 +27,7 @@ func IsServerAllowed(
|
|||
serverCurrentlyInRoom bool,
|
||||
authEvents []gomatrixserverlib.Event,
|
||||
) bool {
|
||||
historyVisibility := historyVisibilityForRoom(authEvents)
|
||||
historyVisibility := HistoryVisibilityForRoom(authEvents)
|
||||
|
||||
// 1. If the history_visibility was set to world_readable, allow.
|
||||
if historyVisibility == "world_readable" {
|
||||
|
@ -52,7 +52,7 @@ func IsServerAllowed(
|
|||
return false
|
||||
}
|
||||
|
||||
func historyVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
|
||||
func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
||||
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared.
|
||||
visibility := "shared"
|
||||
|
|
|
@ -3,6 +3,7 @@ package query
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -62,9 +63,9 @@ FederationHit:
|
|||
logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event")
|
||||
for _, srv := range b.servers { // hit any valid server
|
||||
c := gomatrixserverlib.FederatedStateProvider{
|
||||
FedClient: b.fedClient,
|
||||
AuthEventsOnly: false,
|
||||
Server: srv,
|
||||
FedClient: b.fedClient,
|
||||
RememberAuthEvents: false,
|
||||
Server: srv,
|
||||
}
|
||||
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
|
||||
if err != nil {
|
||||
|
@ -114,7 +115,9 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrix
|
|||
return nil
|
||||
}
|
||||
|
||||
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
||||
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
||||
event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
||||
|
||||
// try to fetch the events from the database first
|
||||
events, err := b.ProvideEvents(roomVer, eventIDs)
|
||||
if err != nil {
|
||||
|
@ -133,9 +136,9 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
|
|||
}
|
||||
|
||||
c := gomatrixserverlib.FederatedStateProvider{
|
||||
FedClient: b.fedClient,
|
||||
AuthEventsOnly: false,
|
||||
Server: b.servers[0],
|
||||
FedClient: b.fedClient,
|
||||
RememberAuthEvents: false,
|
||||
Server: b.servers[0],
|
||||
}
|
||||
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
|
||||
if err != nil {
|
||||
|
@ -160,18 +163,33 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
|
|||
return
|
||||
}
|
||||
|
||||
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
|
||||
if err != nil {
|
||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||
return
|
||||
}
|
||||
|
||||
// possibly return all joined servers depending on history visiblity
|
||||
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||
return
|
||||
}
|
||||
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
|
||||
|
||||
// Retrieve all "m.room.member" state events of "join" membership, which
|
||||
// contains the list of users in the room before the event, therefore all
|
||||
// the servers in it at that moment.
|
||||
events, err := getMembershipsBeforeEventNID(ctx, b.db, NIDs[eventID], true)
|
||||
memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
|
||||
if err != nil {
|
||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
||||
return
|
||||
}
|
||||
memberEvents = append(memberEvents, memberEventsFromVis...)
|
||||
|
||||
// Store the server names in a temporary map to avoid duplicates.
|
||||
serverSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||
for _, event := range events {
|
||||
for _, event := range memberEvents {
|
||||
serverSet[event.Origin()] = true
|
||||
}
|
||||
for server := range serverSet {
|
||||
|
@ -186,7 +204,9 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
|
|||
|
||||
// Backfill performs a backfill request to the given server.
|
||||
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
|
||||
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
|
||||
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
|
||||
fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
|
||||
|
||||
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
|
||||
return &tx, err
|
||||
}
|
||||
|
@ -215,3 +235,44 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
|||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
|
||||
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||
// pull all events and then filter by that table.
|
||||
func joinEventsFromHistoryVisibility(
|
||||
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
|
||||
|
||||
var eventNIDs []types.EventNID
|
||||
for _, entry := range stateEntries {
|
||||
// Filter the events to retrieve to only keep the membership events
|
||||
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
|
||||
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Get all of the events in this state
|
||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
events := make([]gomatrixserverlib.Event, len(stateEvents))
|
||||
for i := range stateEvents {
|
||||
events[i] = stateEvents[i].Event
|
||||
}
|
||||
visibility := auth.HistoryVisibilityForRoom(events)
|
||||
if visibility != "shared" {
|
||||
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
|
||||
return nil, nil
|
||||
}
|
||||
// get joined members
|
||||
roomNID, err := db.RoomNID(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.Events(ctx, joinEventNIDs)
|
||||
}
|
||||
|
|
|
@ -277,6 +277,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
|||
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
|
||||
|
||||
var events []types.Event
|
||||
var stateEntries []types.StateEntry
|
||||
if stillInRoom {
|
||||
var eventNIDs []types.EventNID
|
||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
|
||||
|
@ -286,7 +287,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
|||
|
||||
events, err = r.DB.Events(ctx, eventNIDs)
|
||||
} else {
|
||||
events, err = getMembershipsBeforeEventNID(ctx, r.DB, membershipEventNID, request.JoinedOnly)
|
||||
stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID)
|
||||
if err != nil {
|
||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||
return err
|
||||
}
|
||||
events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -301,15 +307,8 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
|||
return nil
|
||||
}
|
||||
|
||||
// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state
|
||||
// of the event's room as it was when this event was fired, then filters the state events to
|
||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||
// Returns an error if there was an issue fetching the events.
|
||||
func getMembershipsBeforeEventNID(
|
||||
ctx context.Context, db storage.Database, eventNID types.EventNID, joinedOnly bool,
|
||||
) ([]types.Event, error) {
|
||||
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db)
|
||||
events := []types.Event{}
|
||||
// Lookup the event NID
|
||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||
if err != nil {
|
||||
|
@ -323,10 +322,15 @@ func getMembershipsBeforeEventNID(
|
|||
}
|
||||
|
||||
// Fetch the state as it was when this event was fired
|
||||
stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||
}
|
||||
|
||||
// getMembershipsAtState filters the state events to
|
||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||
// Returns an error if there was an issue fetching the events.
|
||||
func getMembershipsAtState(
|
||||
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
||||
) ([]types.Event, error) {
|
||||
|
||||
var eventNIDs []types.EventNID
|
||||
for _, entry := range stateEntries {
|
||||
|
@ -347,6 +351,7 @@ func getMembershipsBeforeEventNID(
|
|||
}
|
||||
|
||||
// Filter the events to only keep the "join" membership events
|
||||
var events []types.Event
|
||||
for _, event := range stateEvents {
|
||||
membership, err := event.Membership()
|
||||
if err != nil {
|
||||
|
@ -563,20 +568,29 @@ func (r *RoomserverQueryAPI) backfillViaFederation(ctx context.Context, req *api
|
|||
if !ok {
|
||||
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
|
||||
// which requires a list of state IDs.
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks")
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
|
||||
continue
|
||||
}
|
||||
var entries []types.StateEntry
|
||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
||||
return err
|
||||
// attempt to fetch the missing events
|
||||
r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
|
||||
// try again
|
||||
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var beforeStateSnapshotNID types.StateSnapshotNID
|
||||
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
||||
return err
|
||||
}
|
||||
util.GetLogger(ctx).Infof("Backfilled event %s (nid=%d) getting snapshot %v with entries %+v", ev.EventID(), ev.EventNID, beforeStateSnapshotNID, entries)
|
||||
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to set state before event")
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -608,6 +622,66 @@ func (r *RoomserverQueryAPI) isServerCurrentlyInRoom(ctx context.Context, server
|
|||
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
|
||||
}
|
||||
|
||||
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
|
||||
// best effort.
|
||||
func (r *RoomserverQueryAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
||||
backfillRequester *backfillRequester, stateIDs []string) {
|
||||
|
||||
servers := backfillRequester.servers
|
||||
|
||||
// work out which are missing
|
||||
nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
|
||||
return
|
||||
}
|
||||
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
|
||||
for _, id := range stateIDs {
|
||||
if _, ok := nidMap[id]; !ok {
|
||||
missingMap[id] = nil
|
||||
}
|
||||
}
|
||||
util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
|
||||
|
||||
// fetch the events from federation. Loop the servers first so if we find one that works we stick with them
|
||||
for _, srv := range servers {
|
||||
for id, ev := range missingMap {
|
||||
if ev != nil {
|
||||
continue // already found
|
||||
}
|
||||
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
|
||||
res, err := r.FedClient.GetEvent(ctx, srv, id)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("failed to get event from server")
|
||||
continue
|
||||
}
|
||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("failed to load and verify event")
|
||||
continue
|
||||
}
|
||||
logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
|
||||
for _, res := range result {
|
||||
if res.Error != nil {
|
||||
logger.WithError(err).Warn("event failed PDU checks")
|
||||
continue
|
||||
}
|
||||
missingMap[id] = res.Event
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var newEvents []gomatrixserverlib.HeaderedEvent
|
||||
for _, ev := range missingMap {
|
||||
if ev != nil {
|
||||
newEvents = append(newEvents, *ev)
|
||||
}
|
||||
}
|
||||
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
||||
persistEvents(ctx, r.DB, newEvents)
|
||||
}
|
||||
|
||||
// TODO: Remove this when we have tests to assert correctness of this function
|
||||
// nolint:gocyclo
|
||||
func (r *RoomserverQueryAPI) scanEventTree(
|
||||
|
@ -857,7 +931,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
|
|||
var stateAtEvent types.StateAtEvent
|
||||
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event")
|
||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
||||
continue
|
||||
}
|
||||
backfilledEventMap[ev.EventID()] = types.Event{
|
||||
|
|
|
@ -86,7 +86,10 @@ func (v StateResolution) LoadStateAtEvent(
|
|||
) ([]types.StateEntry, error) {
|
||||
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
|
||||
}
|
||||
if snapshotNID == 0 {
|
||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||
}
|
||||
|
||||
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
||||
|
|
|
@ -48,11 +48,6 @@ const insertEventSQL = `
|
|||
ON CONFLICT DO NOTHING;
|
||||
`
|
||||
|
||||
const insertEventResultSQL = `
|
||||
SELECT event_nid, state_snapshot_nid FROM roomserver_events
|
||||
WHERE rowid = last_insert_rowid();
|
||||
`
|
||||
|
||||
const selectEventSQL = "" +
|
||||
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
|
||||
|
||||
|
@ -102,7 +97,6 @@ const selectRoomNIDForEventNIDSQL = "" +
|
|||
type eventStatements struct {
|
||||
db *sql.DB
|
||||
insertEventStmt *sql.Stmt
|
||||
insertEventResultStmt *sql.Stmt
|
||||
selectEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||
|
@ -126,7 +120,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
|||
|
||||
return statementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.insertEventResultStmt, insertEventResultSQL},
|
||||
{&s.selectEventStmt, selectEventSQL},
|
||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||
|
@ -152,19 +145,22 @@ func (s *eventStatements) insertEvent(
|
|||
referenceSHA256 []byte,
|
||||
authEventNIDs []types.EventNID,
|
||||
depth int64,
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNID int64
|
||||
var stateNID int64
|
||||
var err error
|
||||
) (types.EventNID, error) {
|
||||
// attempt to insert: the last_row_id is the event NID
|
||||
insertStmt := common.TxStmt(txn, s.insertEventStmt)
|
||||
resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
|
||||
if _, err = insertStmt.ExecContext(
|
||||
result, err := insertStmt.ExecContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
); err == nil {
|
||||
err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||
modified, err := result.RowsAffected()
|
||||
if modified == 0 && err == nil {
|
||||
return 0, sql.ErrNoRows
|
||||
}
|
||||
eventNID, err := result.LastInsertId()
|
||||
return types.EventNID(eventNID), err
|
||||
}
|
||||
|
||||
func (s *eventStatements) selectEvent(
|
||||
|
|
|
@ -124,7 +124,7 @@ func (d *Database) StoreEvent(
|
|||
}
|
||||
}
|
||||
|
||||
if eventNID, stateNID, err = d.statements.insertEvent(
|
||||
if eventNID, err = d.statements.insertEvent(
|
||||
ctx,
|
||||
txn,
|
||||
roomNID,
|
||||
|
|
Loading…
Reference in a new issue