From 2cf4f64731d5b53766b302e3bc1c01c7c28ea3fa Mon Sep 17 00:00:00 2001 From: Shyam Sai Date: Sun, 5 Jan 2025 22:51:43 -0600 Subject: [PATCH] Rename error column and move to back of df --- .../ml/services/openai/OpenAIPrompt.scala | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 00fcdcf944..cb03ba7756 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -15,7 +15,7 @@ import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{DataType, StructType} import scala.collection.JavaConverters._ @@ -111,7 +111,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer postProcessing -> "", postProcessingOptions -> Map.empty, outputCol -> (this.uid + "_output"), - errorCol -> (this.uid + "_error"), + errorCol -> "gen_error", messagesCol -> (this.uid + "_messages"), dropPrompt -> true, systemPrompt -> defaultSystemPrompt, @@ -152,7 +152,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer val df = dataset.toDF val completion = openAICompletion val promptCol = Functions.template(getPromptTemplate) - + val newErrorCol = df.withDerivativeCol(getErrorCol) + setErrorCol(newErrorCol) + completion.setErrorCol(newErrorCol) completion match { case chatCompletion: OpenAIChatCompletion => if (isSet(responseFormat)) { @@ -173,12 +175,13 @@ class OpenAIPrompt(override val uid: String) extends Transformer .getField("message").getField("content"))) .drop(completionNamed.getOutputCol) + // move error col to back of df + val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*) if (getDropPrompt) { - results.drop(messageColName) + resultsFinal.drop(messageColName) } else { - results + resultsFinal } - case completion: OpenAICompletion => if (isSet(responseFormat)) { throw new IllegalArgumentException("responseFormat is not supported for completion models") @@ -194,10 +197,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer .getField("text"))) .drop(completionNamed.getOutputCol) + // move error col to back of df + val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*) if (getDropPrompt) { - results.drop(promptColName) + resultsFinal.drop(promptColName) } else { - results + resultsFinal } } }, dataset.columns.length)