diff --git a/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs b/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs index be13d7376..579293909 100644 --- a/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs +++ b/src/csharp/Microsoft.Spark/Interop/SparkEnvironment.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Dynamic; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Services; @@ -12,6 +14,18 @@ namespace Microsoft.Spark.Interop /// internal static class SparkEnvironment { + private static readonly Lazy s_sparkVersion = new Lazy( + () => new Version((string)JvmBridge.CallStaticJavaMethod( + "org.apache.spark.deploy.dotnet.DotnetRunner", + "SPARK_VERSION"))); + internal static Version SparkVersion + { + get + { + return s_sparkVersion.Value; + } + } + private static IJvmBridge s_jvmBridge; internal static IJvmBridge JvmBridge { diff --git a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs index 40608cbf7..afb455f73 100644 --- a/src/csharp/Microsoft.Spark/Sql/DataFrame.cs +++ b/src/csharp/Microsoft.Spark/Sql/DataFrame.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Net; +using Microsoft.Spark.Interop; using Microsoft.Spark.Interop.Ipc; using Microsoft.Spark.Network; using Microsoft.Spark.Sql.Streaming; @@ -902,20 +903,26 @@ private IEnumerable GetRows(string funcName) /// A tuple of port number and secret string private (int, string) GetConnectionInfo(string funcName) { - var result = _jvmObject.Invoke(funcName); - if (result is int) + object result = _jvmObject.Invoke(funcName); + Version version = SparkEnvironment.SparkVersion; + return (version.Major, version.Minor, version.Build) switch { // In spark 2.3.0, PythonFunction.serveIterator() returns a port number. - return ((int)result, string.Empty); - } - else - { + (2, 3, 0) => ((int)result, string.Empty), // From spark >= 2.3.1, PythonFunction.serveIterator() returns a pair // where the first is a port number and the second is the secret // string to use for the authentication. - var pair = (JvmObjectReference[])result; - return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); - } + (2, 3, _) => ParseConnectionInfo(result), + (2, 4, _) => ParseConnectionInfo(result), + (3, 0, _) => ParseConnectionInfo(result), + _ => throw new NotSupportedException($"Spark {version} not supported.") + }; + } + + private (int, string) ParseConnectionInfo(object info) + { + var pair = (JvmObjectReference[])info; + return ((int)pair[0].Invoke("intValue"), (string)pair[1].Invoke("toString")); } private DataFrame WrapAsDataFrame(object obj) => new DataFrame((JvmObjectReference)obj);