From b500f2bfb675ff28be068bdeb23f87be279e9b56 Mon Sep 17 00:00:00 2001 From: Anant Prakash Date: Thu, 2 Aug 2018 22:58:41 +0530 Subject: [PATCH] GetJoinedHosts from federation server db --- .../storage/joined_hosts_table.go | 15 ++++++++++++++- .../dendrite/federationsender/storage/storage.go | 11 ++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go index 487de9e61..5d652a1a1 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go @@ -97,10 +97,22 @@ func (s *joinedHostsStatements) deleteJoinedHosts( return err } -func (s *joinedHostsStatements) selectJoinedHosts( +func (s *joinedHostsStatements) selectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { stmt := common.TxStmt(txn, s.selectJoinedHostsStmt) + return joinedHostsFromStmt(ctx, stmt, roomID) +} + +func (s *joinedHostsStatements) selectJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) +} + +func joinedHostsFromStmt( + ctx context.Context, stmt *sql.Stmt, roomID string, +) ([]types.JoinedHost, error) { rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err @@ -118,5 +130,6 @@ func (s *joinedHostsStatements) selectJoinedHosts( ServerName: gomatrixserverlib.ServerName(serverName), }) } + return result, nil } diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go index e84d639d0..3a0f87752 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go @@ -92,7 +92,7 @@ func (d *Database) UpdateRoom( } } - joinedHosts, err = d.selectJoinedHosts(ctx, txn, roomID) + joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) if err != nil { return err } @@ -110,3 +110,12 @@ func (d *Database) UpdateRoom( }) return } + +// GetJoinedHosts returns the currently joined hosts for room, +// as known to federationserver. +// Returns an error if something goes wrong. +func (d *Database) GetJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return d.selectJoinedHosts(ctx, roomID) +}