From 2f7d4cb2d2ae28e98ba7c4408f6f85703d5b2ee2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 25 Apr 2023 15:19:12 -0600 Subject: [PATCH] Refactor PerformJoin to use the new gmsl interface --- federationapi/api/api.go | 1 + federationapi/federationapi_test.go | 2 ++ federationapi/internal/federationclient.go | 30 ++++++++++++++++++++++ federationapi/internal/perform.go | 4 +-- 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 0048b4b04..c223f5045 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -16,6 +16,7 @@ import ( // FederationInternalAPI is used to query information from the federation sender. type FederationInternalAPI interface { gomatrixserverlib.FederatedStateClient + gomatrixserverlib.FederatedJoinClient KeyserverFederationAPI gomatrixserverlib.KeyDatabase ClientFederationAPI diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index cb4684858..46b67aa21 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -107,6 +107,8 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer spec.ServerN } func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roomID, userID string) (res fclient.RespMakeJoin, err error) { + f.fedClientMutex.Lock() + defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index e4288a20c..957166bd8 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -12,6 +12,36 @@ import ( // Functions here are "proxying" calls to the gomatrixserverlib federation // client. +func (a *FederationInternalAPI) MakeJoin( + ctx context.Context, origin, s spec.ServerName, roomID, userID string, +) (res gomatrixserverlib.MakeJoinResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.MakeJoin(ctx, origin, s, roomID, userID) + }) + if err != nil { + return &fclient.RespMakeJoin{}, err + } + r := ires.(fclient.RespMakeJoin) + return &r, nil +} + +func (a *FederationInternalAPI) SendJoin( + ctx context.Context, origin, s spec.ServerName, event *gomatrixserverlib.Event, +) (res gomatrixserverlib.SendJoinResponse, err error) { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { + return a.federation.SendJoin(ctx, origin, s, event) + }) + if err != nil { + return &fclient.RespSendJoin{}, err + } + r := ires.(fclient.RespSendJoin) + return &r, nil +} + func (a *FederationInternalAPI) GetEventAuth( ctx context.Context, origin, s spec.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d3c164afd..1eeae380f 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -150,7 +150,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( return err } - joinInput := fclient.PerformJoinInput{ + joinInput := gomatrixserverlib.PerformJoinInput{ UserID: user, RoomID: roomID, ServerName: serverName, @@ -161,7 +161,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( KeyRing: r.keyRing, EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), } - response, joinErr := fclient.PerformJoin(ctx, r.federation, joinInput) + response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) if err != nil { if !joinErr.Reachable {