diff --git a/internal/cosmosdbapi/client.go b/internal/cosmosdbapi/client.go index eeae75573..e6a271d3f 100644 --- a/internal/cosmosdbapi/client.go +++ b/internal/cosmosdbapi/client.go @@ -2,7 +2,9 @@ package cosmosdbapi import ( "context" + "crypto/tls" "errors" + "net/http" "strings" "time" @@ -10,21 +12,30 @@ import ( ) type CosmosConnection struct { - Url string - Key string + Url string + Key string + DisableCertificateValidation bool } -func GetCosmosConnection(accountEndpoint string, accountKey string) CosmosConnection { +func GetCosmosConnection(accountEndpoint string, accountKey string, disableCertificateValidation bool) CosmosConnection { return CosmosConnection{ - Url: accountEndpoint, - Key: accountKey, + Url: accountEndpoint, + Key: accountKey, + DisableCertificateValidation: disableCertificateValidation, } } +func disableCertificateValidation() { + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} +} + func GetClient(conn CosmosConnection) *cosmosapi.Client { cfg := cosmosapi.Config{ MasterKey: conn.Key, } + if conn.DisableCertificateValidation { + disableCertificateValidation() + } return cosmosapi.New(conn.Url, cfg, nil, nil) } diff --git a/internal/cosmosdbutil/connection.go b/internal/cosmosdbutil/connection.go index 3b7bb8db2..2e078e5aa 100644 --- a/internal/cosmosdbutil/connection.go +++ b/internal/cosmosdbutil/connection.go @@ -1,6 +1,7 @@ package cosmosdbutil import ( + "strconv" "strings" "github.com/matrix-org/dendrite/internal/cosmosdbapi" @@ -12,10 +13,10 @@ const accountKeyName = "AccountKey" const databaseName = "DatabaseName" const containerName = "ContainerName" const tenantName = "TenantName" +const disableCertificateValidationName = "DisableCertificateValidation" func getConnectionString(d *config.DataSource) config.DataSource { - var connString string - connString = string(*d) + connString := string(*d) return config.DataSource(strings.Replace(connString, "cosmosdb:", "", 1)) } @@ -36,7 +37,15 @@ func GetCosmosConnection(d *config.DataSource) cosmosdbapi.CosmosConnection { connMap := getConnectionProperties(string(connString)) accountEndpoint := connMap[accountEndpointName] accountKey := connMap[accountKeyName] - return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey) + value, ok := connMap[disableCertificateValidationName] + disableCertificateValidation := false + if ok { + valueBool, err := strconv.ParseBool(value) + if err == nil { + disableCertificateValidation = valueBool + } + } + return cosmosdbapi.GetCosmosConnection(accountEndpoint, accountKey, disableCertificateValidation) } func GetCosmosConfig(d *config.DataSource) cosmosdbapi.CosmosConfig {