diff --git a/federationsender/api/api.go b/federationsender/api/api.go index dc0856723..04ac63462 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -114,7 +114,7 @@ type PerformJoinResponse struct { type PerformPeekRequest struct { RoomID string `json:"room_id"` // The sorted list of servers to try. Servers will be tried sequentially, after de-duplication. - ServerNames types.ServerNames `json:"server_names"` + ServerNames types.ServerNames `json:"server_names"` } type PerformPeekResponse struct { diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 92dfa89a5..14df6c861 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -186,7 +186,7 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( // Check that the send_join response was valid. joinCtx := perform.JoinContext(r.federation, r.keyRing) respState, err := joinCtx.CheckSendJoinResponse( - ctx, event, serverName, respMakeJoin, respSendJoin, + ctx, event, serverName, respSendJoin, ) if err != nil { return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) @@ -195,10 +195,11 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. + headeredEvent := event.Headered(respMakeJoin.RoomVersion) if err = roomserverAPI.SendEventWithState( ctx, r.rsAPI, respState, - event.Headered(respMakeJoin.RoomVersion), + &headeredEvent, nil, ); err != nil { return fmt.Errorf("r.producer.SendEventWithState: %w", err) @@ -293,18 +294,17 @@ func (r *FederationSenderInternalAPI) performPeekUsingServer( // check whether we're peeking already to try to avoid needlessly // re-peeking on the server. we don't need a transaction for this, // given this is a nice-to-have. - remotePeek, err := r.db.GetRemotePeek(ctx, roomID, serverName, peekID) + remotePeek, err := r.db.GetRemotePeek(ctx, serverName, roomID, peekID) if err != nil { return err } renewing := false if remotePeek != nil { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - if (nowMilli > remotePeek.RenewedTimestamp + remotePeek.RenewalInterval) { + if nowMilli > remotePeek.RenewedTimestamp+remotePeek.RenewalInterval { logrus.Infof("stale remote peek to %s for %s already exists; renewing", serverName, roomID) renewing = true - } - else { + } else { logrus.Infof("live remote peek to %s for %s already exists", serverName, roomID) return nil } @@ -336,12 +336,11 @@ func (r *FederationSenderInternalAPI) performPeekUsingServer( // If we've got this far, the remote server is peeking. if renewing { - if err = r.db.RenewRemotePeek(ctx, serverName, roomID, respPeek.RenewalInterval); err != nil { + if err = r.db.RenewRemotePeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil { return err } - } - else { - if err = r.db.AddRemotePeek(ctx, serverName, roomID, respPeek.RenewalInterval); err != nil { + } else { + if err = r.db.AddRemotePeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil { return err } } @@ -351,7 +350,7 @@ func (r *FederationSenderInternalAPI) performPeekUsingServer( if err = roomserverAPI.SendEventWithState( ctx, r.rsAPI, &respState, - event.Headered(respPeek.RoomVersion), nil, + nil, nil, ); err != nil { return fmt.Errorf("r.producer.SendEventWithState: %w", err) } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 75de07f4d..49d3937bf 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -50,12 +50,13 @@ type Database interface { GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + // XXX: why don't these have contexts passed in? AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) - AddRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error - RenewRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error - GetRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string) (types.RemotePeek, error) - GetRemotePeeks(roomID string) ([]types.RemotePeek, error) + AddRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error + RenewRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error + GetRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.RemotePeek, error) + GetRemotePeeks(ctx context.Context, roomID string) ([]types.RemotePeek, error) } diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 4fa66ab0a..4985c5eda 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -165,22 +165,22 @@ func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) return d.FederationSenderBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error { +func (d *Database) AddRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationSenderRemotePeeks.InsertRemotePeek(context.TODO(), txn, serverName, roomID, peekID, renewalInterval) + return d.FederationSenderRemotePeeks.InsertRemotePeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error { +func (d *Database) RenewRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.FederationSenderRemotePeeks.RenewRemotePeek(context.TODO(), txn, serverName, roomID, peekID, renewalInterval) + return d.FederationSenderRemotePeeks.RenewRemotePeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetRemotePeek(serverName gomatrixserverlib.ServerName, roomID, peekID string) (types.RemotePeek, error) { - return d.FederationSenderRemotePeeks.SelectRemotePeek(context.TODO(), serverName, roomID, peekID) +func (d *Database) GetRemotePeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.RemotePeek, error) { + return d.FederationSenderRemotePeeks.SelectRemotePeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetRemotePeeks(roomID string) ([]types.RemotePeek, error) { - return d.FederationSenderRemotePeeks.SelectRemotePeeks(context.TODO(), roomID) +func (d *Database) GetRemotePeeks(ctx context.Context, roomID string) ([]types.RemotePeek, error) { + return d.FederationSenderRemotePeeks.SelectRemotePeeks(ctx, nil, roomID) } diff --git a/federationsender/storage/sqlite3/remote_peeks_table.go b/federationsender/storage/sqlite3/remote_peeks_table.go index 5b05cded1..19eef880c 100644 --- a/federationsender/storage/sqlite3/remote_peeks_table.go +++ b/federationsender/storage/sqlite3/remote_peeks_table.go @@ -17,7 +17,10 @@ package sqlite3 import ( "context" "database/sql" + "time" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -97,7 +100,7 @@ func (s *remotePeeksStatements) InsertRemotePeek( ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) stmt := sqlutil.TxStmt(txn, s.insertRemotePeekStmt) - _, err := stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) return } @@ -105,31 +108,34 @@ func (s *remotePeeksStatements) RenewRemotePeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - _, err := sqlutil.TxStmt(txn, s.renewRemotePeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + _, err = sqlutil.TxStmt(txn, s.renewRemotePeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) return } - func (s *remotePeeksStatements) SelectRemotePeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, -) (remotePeek types.RemotePeek, err error) { +) (*types.RemotePeek, error) { rows, err := sqlutil.TxStmt(txn, s.selectRemotePeeksStmt).QueryContext(ctx, roomID) if err != nil { - return + return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "SelectRemotePeek: rows.close() failed") remotePeek := types.RemotePeek{} - if err = rows.Scan( + err = rows.Scan( &remotePeek.RoomID, &remotePeek.ServerName, &remotePeek.PeekID, &remotePeek.CreationTimestamp, - &remotePeek.RenewTimestamp, + &remotePeek.RenewedTimestamp, &remotePeek.RenewalInterval, - ); err != nil { - return + ) + if err == sql.ErrNoRows { + return nil, nil } - return remotePeek, rows.Err() + if err != nil { + return nil, err + } + return &remotePeek, rows.Err() } func (s *remotePeeksStatements) SelectRemotePeeks( @@ -148,7 +154,7 @@ func (s *remotePeeksStatements) SelectRemotePeeks( &remotePeek.ServerName, &remotePeek.PeekID, &remotePeek.CreationTimestamp, - &remotePeek.RenewTimestamp, + &remotePeek.RenewedTimestamp, &remotePeek.RenewalInterval, ); err != nil { return @@ -162,13 +168,13 @@ func (s *remotePeeksStatements) SelectRemotePeeks( func (s *remotePeeksStatements) DeleteRemotePeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ) (err error) { - _, err := sqlutil.TxStmt(txn, s.deleteRemotePeekStmt).ExecContext(ctx, roomID, serverName, peekID) + _, err = sqlutil.TxStmt(txn, s.deleteRemotePeekStmt).ExecContext(ctx, roomID, serverName, peekID) return } func (s *remotePeeksStatements) DeleteRemotePeeks( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err := sqlutil.TxStmt(txn, s.deleteRemotePeeksStmt).ExecContext(ctx, roomID) + _, err = sqlutil.TxStmt(txn, s.deleteRemotePeeksStmt).ExecContext(ctx, roomID) return } diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index 85a6ba98b..e35eac0fb 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -71,7 +71,7 @@ type FederationSenderBlacklist interface { type FederationSenderRemotePeeks interface { InsertRemotePeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) (err error) RenewRemotePeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int) (err error) - SelectRemotePeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (remotePeek types.RemotePeek, err error) + SelectRemotePeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (remotePeek *types.RemotePeek, err error) SelectRemotePeeks(ctx context.Context, txn *sql.Tx, roomID string) (remotePeeks []types.RemotePeek, err error) DeleteRemotePeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error) DeleteRemotePeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error) diff --git a/federationsender/types/types.go b/federationsender/types/types.go index 08182262a..d59327796 100644 --- a/federationsender/types/types.go +++ b/federationsender/types/types.go @@ -50,14 +50,11 @@ func (e EventIDMismatchError) Error() string { ) } -// UnixMs is the milliseconds since the Unix epoch -type UnixMs int64 - type RemotePeek struct { - PeekID string - RoomID string - ServerName gomatrixserverlib.ServerName - CreatedTimestamp UnixMs - RenewedTimestamp UnixMs - RenewalInterval UnixMs -} \ No newline at end of file + PeekID string + RoomID string + ServerName gomatrixserverlib.ServerName + CreationTimestamp int64 + RenewedTimestamp int64 + RenewalInterval int64 +}