// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres
2017-02-15 08:43:19 -06:00
import (
2017-09-13 10:30:19 -05:00
"context"
2017-02-15 08:43:19 -06:00
"database/sql"
"fmt"
2017-08-07 05:51:46 -05:00
2017-02-15 08:43:19 -06:00
"github.com/lib/pq"
2023-02-07 07:31:23 -06:00
"github.com/matrix-org/gomatrixserverlib"
2022-10-11 09:04:02 -05:00
"github.com/matrix-org/util"
2020-09-24 05:10:14 -05:00
"github.com/matrix-org/dendrite/internal/sqlutil"
2017-02-15 08:43:19 -06:00
"github.com/matrix-org/dendrite/roomserver/types"
)
// stateSnapshotSchema is the DDL that creates the sequence used to allocate
// state snapshot NIDs and the roomserver_state_snapshots table itself.
const stateSnapshotSchema = `
-- The state of a room before an event.
-- Stored as a list of state_block entries stored in a separate table.
-- The actual state is constructed by combining all the state_block entries
-- referenced by state_block_nids together. If the same state key tuple appears
-- multiple times then the entry from the later state_block clobbers the earlier
-- entries.
-- This encoding format allows us to implement a delta encoding which is useful
-- because room state tends to accumulate small changes over time. Although if
-- the list of deltas becomes too long it becomes more efficient to encode
-- the full state under single state_block_nid.
CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq;
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
	-- The state snapshot NID that identifies this snapshot.
	state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
	-- The hash of the state snapshot, which is used to enforce uniqueness. The hash is
	-- generated in Dendrite and passed through to the database, as a btree index over
	-- this column is cheap and fits within the maximum index size.
	state_snapshot_hash BYTEA UNIQUE,
	-- The room NID that the snapshot belongs to.
	room_nid bigint NOT NULL,
	-- The state blocks contained within this snapshot.
	state_block_nids bigint[] NOT NULL
);
`
2021-04-26 07:25:57 -05:00
// insertStateSQL inserts a new state snapshot and returns its NID.
// If we conflict on the hash column then we must perform an update so that
// the RETURNING clause returns the ID of the row that we conflicted with,
// so that we can then refer to the original snapshot.
const insertStateSQL = "" +
	"INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)" +
	" VALUES ($1, $2, $3)" +
	" ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2" +
	// Performing an update, above, ensures that the RETURNING statement
	// below will always return a valid state snapshot ID
	" RETURNING state_snapshot_nid"
// bulkSelectStateBlockNIDsSQL is the bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
const bulkSelectStateBlockNIDsSQL = "" +
	"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
	" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
2022-08-01 08:11:00 -05:00
// bulkSelectStateForHistoryVisibilitySQL looks up both the history visibility
// event and relevant membership events from a given domain name from a given
// state snapshot. This is used to optimise the
// helpers.CheckServerAllowedToSeeEvent function.
//
// Parameters: $1 is the state snapshot NID, $2 is the server name; membership
// events are matched by a state key ending in ":" followed by $2.
//
// TODO: There's a sequence scan here because of the hash join strategy, which is
// probably O(n) on state key entries, so there must be a way to avoid that somehow.
//
// Event type NIDs are:
// - 5: m.room.member as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L40
// - 7: m.room.history_visibility as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L42
const bulkSelectStateForHistoryVisibilitySQL = `
	SELECT event_nid FROM (
	  SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events
	  WHERE (event_type_nid = 5 OR event_type_nid = 7)
	  AND event_nid = ANY(
	    SELECT UNNEST(event_nids) FROM roomserver_state_block
	    WHERE state_block_nid = ANY(
	      SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots
	      WHERE state_snapshot_nid = $1
	    )
	  )
	  ORDER BY depth ASC
	) AS roomserver_events
	INNER JOIN roomserver_event_state_keys
	  ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid
	  AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
`
2023-02-07 07:31:23 -06:00
// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for a defined set of events.
// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON.
//
// Parameters: $1 is the state key NID of the user, $2 is the array of event IDs
// to look up, $3 is the room NID. Event type NID 5 is matched for the
// membership events.
const bulkSelectMembershipForHistoryVisibilitySQL = `
SELECT re.event_id, re2.event_id, rej.event_json
FROM roomserver_events re
LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid
CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid)
LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid
CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid)
JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1
LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid
WHERE re.event_id = ANY($2)
`
2017-02-15 08:43:19 -06:00
// stateSnapshotStatements holds the prepared statements used to read and
// write the roomserver_state_snapshots table. Each field is populated by
// PrepareStateSnapshotTable from the SQL constant of the matching name.
//
// NOTE(review): bulktSelectMembershipForHistoryVisibilityStmt is misspelt
// ("bulkt"); the name is kept because other code in this file refers to it.
type stateSnapshotStatements struct {
	insertStateStmt                               *sql.Stmt
	bulkSelectStateBlockNIDsStmt                  *sql.Stmt
	bulkSelectStateForHistoryVisibilityStmt       *sql.Stmt
	bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt
}
2022-05-16 12:33:16 -05:00
func CreateStateSnapshotTable ( db * sql . DB ) error {
2020-05-27 03:36:09 -05:00
_ , err := db . Exec ( stateSnapshotSchema )
2021-04-26 07:25:57 -05:00
return err
}
2023-02-07 07:31:23 -06:00
func PrepareStateSnapshotTable ( db * sql . DB ) ( * stateSnapshotStatements , error ) {
2021-04-26 07:25:57 -05:00
s := & stateSnapshotStatements { }
2017-03-07 04:37:41 -06:00
2021-07-28 12:30:04 -05:00
return s , sqlutil . StatementList {
2017-03-07 04:37:41 -06:00
{ & s . insertStateStmt , insertStateSQL } ,
{ & s . bulkSelectStateBlockNIDsStmt , bulkSelectStateBlockNIDsSQL } ,
2022-08-01 08:11:00 -05:00
{ & s . bulkSelectStateForHistoryVisibilityStmt , bulkSelectStateForHistoryVisibilitySQL } ,
2023-02-07 07:31:23 -06:00
{ & s . bulktSelectMembershipForHistoryVisibilityStmt , bulkSelectMembershipForHistoryVisibilitySQL } ,
2020-05-27 05:03:47 -05:00
} . Prepare ( db )
2017-02-15 08:43:19 -06:00
}
2020-05-27 03:36:09 -05:00
func ( s * stateSnapshotStatements ) InsertState (
2021-04-26 07:25:57 -05:00
ctx context . Context , txn * sql . Tx , roomNID types . RoomNID , nids types . StateBlockNIDs ,
2017-09-13 10:30:19 -05:00
) ( stateNID types . StateSnapshotNID , err error ) {
2021-04-26 07:25:57 -05:00
nids = nids [ : util . SortAndUnique ( nids ) ]
2022-05-16 12:33:16 -05:00
err = sqlutil . TxStmt ( txn , s . insertStateStmt ) . QueryRowContext ( ctx , nids . Hash ( ) , int64 ( roomNID ) , stateBlockNIDsAsArray ( nids ) ) . Scan ( & stateNID )
2021-04-26 07:25:57 -05:00
if err != nil {
return 0 , err
2017-02-15 08:43:19 -06:00
}
return
}
2020-05-27 03:36:09 -05:00
func ( s * stateSnapshotStatements ) BulkSelectStateBlockNIDs (
2022-02-04 04:39:34 -06:00
ctx context . Context , txn * sql . Tx , stateNIDs [ ] types . StateSnapshotNID ,
2017-09-13 10:30:19 -05:00
) ( [ ] types . StateBlockNIDList , error ) {
2017-02-15 08:43:19 -06:00
nids := make ( [ ] int64 , len ( stateNIDs ) )
for i := range stateNIDs {
nids [ i ] = int64 ( stateNIDs [ i ] )
}
2022-02-04 04:39:34 -06:00
stmt := sqlutil . TxStmt ( txn , s . bulkSelectStateBlockNIDsStmt )
rows , err := stmt . QueryContext ( ctx , pq . Int64Array ( nids ) )
2017-02-15 08:43:19 -06:00
if err != nil {
return nil , err
}
2017-09-20 04:59:19 -05:00
defer rows . Close ( ) // nolint: errcheck
2017-02-15 08:43:19 -06:00
results := make ( [ ] types . StateBlockNIDList , len ( stateNIDs ) )
i := 0
2022-05-16 12:33:16 -05:00
var stateBlockNIDs pq . Int64Array
2017-02-15 08:43:19 -06:00
for ; rows . Next ( ) ; i ++ {
result := & results [ i ]
2020-02-11 08:12:21 -06:00
if err = rows . Scan ( & result . StateSnapshotNID , & stateBlockNIDs ) ; err != nil {
2017-02-15 08:43:19 -06:00
return nil , err
}
result . StateBlockNIDs = make ( [ ] types . StateBlockNID , len ( stateBlockNIDs ) )
for k := range stateBlockNIDs {
result . StateBlockNIDs [ k ] = types . StateBlockNID ( stateBlockNIDs [ k ] )
}
}
2020-02-11 08:12:21 -06:00
if err = rows . Err ( ) ; err != nil {
return nil , err
}
2017-02-15 08:43:19 -06:00
if i != len ( stateNIDs ) {
2022-02-21 10:22:29 -06:00
return nil , types . MissingStateError ( fmt . Sprintf ( "storage: state NIDs missing from the database (%d != %d)" , i , len ( stateNIDs ) ) )
2017-02-15 08:43:19 -06:00
}
return results , nil
}
2022-08-01 08:11:00 -05:00
// BulkSelectStateForHistoryVisibility runs the prepared history-visibility
// query for the given state snapshot and server domain and collects the event
// NIDs it yields, in the order the database returns them.
//
// The statement is bound to txn via sqlutil.TxStmt. A query or scan failure
// yields a nil slice and the error; an iteration error is returned alongside
// the rows gathered so far.
func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
	ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
) ([]types.EventNID, error) {
	rows, err := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt).
		QueryContext(ctx, stateSnapshotNID, domain)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck

	eventNIDs := make([]types.EventNID, 0, 16)
	var nid types.EventNID
	for rows.Next() {
		if scanErr := rows.Scan(&nid); scanErr != nil {
			return nil, scanErr
		}
		eventNIDs = append(eventNIDs, nid)
	}
	return eventNIDs, rows.Err()
}
2023-02-07 07:31:23 -06:00
func ( s * stateSnapshotStatements ) BulkSelectMembershipForHistoryVisibility (
ctx context . Context , txn * sql . Tx , userNID types . EventStateKeyNID , roomInfo * types . RoomInfo , eventIDs ... string ,
) ( map [ string ] * gomatrixserverlib . HeaderedEvent , error ) {
stmt := sqlutil . TxStmt ( txn , s . bulktSelectMembershipForHistoryVisibilityStmt )
rows , err := stmt . QueryContext ( ctx , userNID , pq . Array ( eventIDs ) , roomInfo . RoomNID )
if err != nil {
return nil , err
}
defer rows . Close ( ) // nolint: errcheck
result := make ( map [ string ] * gomatrixserverlib . HeaderedEvent , len ( eventIDs ) )
var evJson [ ] byte
var eventID string
var membershipEventID string
knownEvents := make ( map [ string ] * gomatrixserverlib . HeaderedEvent , len ( eventIDs ) )
2023-04-21 11:06:29 -05:00
verImpl , err := gomatrixserverlib . GetRoomVersion ( roomInfo . RoomVersion )
if err != nil {
return nil , err
}
2023-02-07 07:31:23 -06:00
for rows . Next ( ) {
if err = rows . Scan ( & eventID , & membershipEventID , & evJson ) ; err != nil {
return nil , err
}
if len ( evJson ) == 0 {
result [ eventID ] = & gomatrixserverlib . HeaderedEvent { }
continue
}
// If we already know this event, don't try to marshal the json again
if ev , ok := knownEvents [ membershipEventID ] ; ok {
result [ eventID ] = ev
continue
}
2023-04-21 11:06:29 -05:00
event , err := verImpl . NewEventFromTrustedJSON ( evJson , false )
2023-02-07 07:31:23 -06:00
if err != nil {
result [ eventID ] = & gomatrixserverlib . HeaderedEvent { }
// not fatal
continue
}
he := event . Headered ( roomInfo . RoomVersion )
result [ eventID ] = he
knownEvents [ membershipEventID ] = he
}
return result , rows . Err ( )
}