diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 5a036d889..343d635bf 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -157,6 +157,13 @@ type Database interface { IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error + // SelectTopologicalEvent selects an event before and including the given position by eventType and roomID. Returns the found event and the topology token. + // If not event was found, returns nil and sql.ErrNoRows. + SelectTopologicalEvent(ctx context.Context, topologicalPosition int, eventType, roomID string) (*gomatrixserverlib.HeaderedEvent, types.TopologyToken, error) + // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found + // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty + // string as the membership. + SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int) (membership string, topologicalPos int, err error) } type Presence interface { diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 00223c57a..0b959bab2 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { - upsertMembershipStmt *sql.Stmt - selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt + upsertMembershipStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectHeroesStmt *sql.Stmt + selectMembershipForUserStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectHeroesStmt, selectHeroesSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, }.Prepare(db) } @@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index d84d0cfa2..c3df0d7ea 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -164,6 +164,13 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const selectTopologicalEventSQL = "" + + "SELECT se.headered_event_json, st.topological_position, st.stream_position " + + " FROM syncapi_output_room_events_topology st " + + " JOIN syncapi_output_room_events se ON se.event_id = st.event_id " + + " WHERE st.room_id = $1 AND st.topological_position < $2 AND se.type = $3 " + + " ORDER BY st.topological_position DESC LIMIT 1" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -178,6 +185,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + selectTopologicalEventStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -200,6 +208,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.selectTopologicalEventStmt, selectTopologicalEventSQL}, }.Prepare(db) } @@ -580,6 +589,32 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( return lastID, evts, rows.Err() } +// SelectTopologicalEvent selects an event before and including the given position by eventType and roomID. Returns the found event and the topology token. +// If not event was found, returns nil and sql.ErrNoRows. +func (s *outputRoomEventsStatements) SelectTopologicalEvent( + ctx context.Context, txn *sql.Tx, topologicalPosition int, eventType, roomID string, +) (*gomatrixserverlib.HeaderedEvent, types.TopologyToken, error) { + var ( + eventBytes []byte + token types.TopologyToken + ) + + err := sqlutil.TxStmtContext(ctx, txn, s.selectTopologicalEventStmt). + QueryRowContext(ctx, roomID, topologicalPosition, eventType). + Scan(&eventBytes, &token.Depth, &token.PDUPosition) + + if err != nil { + return nil, types.TopologyToken{}, err + } + + var res *gomatrixserverlib.HeaderedEvent + if err = json.Unmarshal(eventBytes, &res); err != nil { + return nil, types.TopologyToken{}, err + } + + return res, token, nil +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 9cfe7c070..d2693adc4 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -46,6 +46,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if err != nil { return nil, err } + topology, err := NewPostgresTopologyTable(d.db) + if err != nil { + return nil, err + } events, err := NewPostgresEventsTable(d.db) if err != nil { return nil, err @@ -62,10 +66,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if err != nil { return nil, err } - topology, err := NewPostgresTopologyTable(d.db) - if err != nil { - return nil, err - } backwardExtremities, err := NewPostgresBackwardsExtremitiesTable(d.db) if err != nil { return nil, err diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index ec5edd355..34eaca565 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -227,7 +227,7 @@ func (d *Database) AddPeek( return } -// DeletePeeks tracks the fact that a user has stopped peeking from the specified +// DeletePeek tracks the fact that a user has stopped peeking from the specified // device. If the peeks was successfully deleted this returns the stream ID it was // stored at. Returns an error if there was a problem communicating with the database. func (d *Database) DeletePeek( @@ -557,7 +557,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda return err } -// Retrieve the backward topology position, i.e. the position of the +// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // oldest event in the room's topology. func (d *Database) GetBackwardTopologyPos( ctx context.Context, @@ -668,7 +668,7 @@ func (d *Database) fetchMissingStateEvents( return events, nil } -// getStateDeltas returns the state deltas between fromPos and toPos, +// GetStateDeltas returns the state deltas between fromPos and toPos, // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. @@ -806,7 +806,7 @@ func (d *Database) GetStateDeltas( return deltas, joinedRoomIDs, nil } -// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // requests with full_state=true. // Fetches full state for all joined rooms and uses selectStateInRange to get // updates for other rooms. @@ -1033,37 +1033,45 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) } -func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } -func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) } -func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) } -func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { - return s.Ignores.SelectIgnores(ctx, userID) +func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { + return d.Ignores.SelectIgnores(ctx, userID) } -func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { - return s.Ignores.UpsertIgnores(ctx, userID, ignores) +func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { + return d.Ignores.UpsertIgnores(ctx, userID, ignores) } -func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { - return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) } -func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return s.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, nil, userID) } -func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return s.Presence.GetPresenceAfter(ctx, nil, after, filter) +func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, nil, after, filter) } -func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return s.Presence.GetMaxPresenceID(ctx, nil) +func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, nil) +} + +func (d *Database) SelectTopologicalEvent(ctx context.Context, topologicalPosition int, eventType, roomID string) (*gomatrixserverlib.HeaderedEvent, types.TopologyToken, error) { + return d.OutputEvents.SelectTopologicalEvent(ctx, nil, topologicalPosition, eventType, roomID) +} + +func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int) (membership string, topologicalPos int, err error) { + return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index e4daa99c1..f5c887302 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic + selectMembershipForUserStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index f9961a9d1..1c619d0fe 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -114,6 +114,13 @@ const selectContextAfterEventSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectTopologicalEventSQL = "" + + "SELECT headered_event_json, topological_position, stream_position " + + " FROM syncapi_output_room_events_topology " + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_output_room_events_topology.event_id " + + " WHERE syncapi_output_room_events_topology.room_id = $1 AND topological_position < $2 AND type = $3 " + + " ORDER BY topological_position DESC LIMIT 1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -124,6 +131,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + selectTopologicalEventStmt *sql.Stmt } func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { @@ -143,6 +151,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.selectTopologicalEventStmt, selectTopologicalEventSQL}, }.Prepare(db) } @@ -600,6 +609,31 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent( return lastID, evts, rows.Err() } +// SelectTopologicalEvent selects an event before and including the given position by eventType and roomID. Returns the found event and the topology token. +// If not event was found, returns nil and sql.ErrNoRows. +func (s *outputRoomEventsStatements) SelectTopologicalEvent( + ctx context.Context, txn *sql.Tx, topologicalPosition int, eventType, roomID string, +) (*gomatrixserverlib.HeaderedEvent, types.TopologyToken, error) { + var ( + eventBytes []byte + token types.TopologyToken + ) + err := sqlutil.TxStmtContext(ctx, txn, s.selectTopologicalEventStmt). + QueryRowContext(ctx, roomID, topologicalPosition, eventType). + Scan(&eventBytes, &token.Depth, &token.PDUPosition) + + if err != nil { + return nil, types.TopologyToken{}, err + } + + var res *gomatrixserverlib.HeaderedEvent + if err = json.Unmarshal(eventBytes, &res); err != nil { + return nil, types.TopologyToken{}, err + } + + return res, token, nil +} + func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs []string, err error) { if len(addIDsJSON) > 0 { if err = json.Unmarshal([]byte(addIDsJSON), &addIDs); err != nil { diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index e08a0ba82..9f56e5e42 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -56,6 +56,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er if err != nil { return err } + topology, err := NewSqliteTopologyTable(d.db) + if err != nil { + return err + } events, err := NewSqliteEventsTable(d.db, &d.streamID) if err != nil { return err @@ -72,10 +76,6 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er if err != nil { return err } - topology, err := NewSqliteTopologyTable(d.db) - if err != nil { - return err - } bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) if err != nil { return err diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index ccdebfdbd..e6ec797fe 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -67,6 +67,7 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + SelectTopologicalEvent(ctx context.Context, txn *sql.Tx, topologicalPosition int, eventType, roomID string) (*gomatrixserverlib.HeaderedEvent, types.TopologyToken, error) } // Topology keeps track of the depths and stream positions for all events. @@ -173,6 +174,7 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) + SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int) (membership string, topologicalPos int, err error) } type NotificationData interface { diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index 69bbd04c9..8604cdebd 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -29,12 +29,20 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, var tab tables.Events switch dbType { case test.DBTypePostgres: + _, err = postgres.NewPostgresTopologyTable(db) // needed, since there is a join on it + if err != nil { + t.Fatalf("unable to create table: %s", err) + } tab, err = postgres.NewPostgresEventsTable(db) case test.DBTypeSQLite: var stream sqlite3.StreamIDStatements if err = stream.Prepare(db); err != nil { t.Fatalf("failed to prepare stream stmts: %s", err) } + _, err = sqlite3.NewSqliteTopologyTable(db) // needed, since there is a join on it + if err != nil { + t.Fatalf("unable to create table: %s", err) + } tab, err = sqlite3.NewSqliteEventsTable(db, &stream) } if err != nil {