diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index db26c44faa..f0ff261d56 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -15,8 +15,8 @@ import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DataType, StructField, StructType} import scala.collection.JavaConverters._ @@ -172,13 +172,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1) .getField("message").getField("content"))) .drop(completionNamed.getOutputCol) - + val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*) if (getDropPrompt) { - results.drop(messageColName) + resultsFinal.drop(messageColName) } else { - results + resultsFinal } - case completion: OpenAICompletion => if (isSet(responseFormat)) { throw new IllegalArgumentException("responseFormat is not supported for completion models") @@ -186,18 +185,17 @@ class OpenAIPrompt(override val uid: String) extends Transformer val promptColName = df.withDerivativeCol("prompt") val dfTemplated = df.withColumn(promptColName, promptCol) val completionNamed = completion.setPromptCol(promptColName) - // run completion val results = completionNamed .transform(dfTemplated) .withColumn(getOutputCol, getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1) .getField("text"))) .drop(completionNamed.getOutputCol) - + val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*) if (getDropPrompt) { - results.drop(promptColName) + resultsFinal.drop(promptColName) } else { - results + resultsFinal } } }, dataset.columns.length) @@ -238,7 +236,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer } // apply all parameters extractParamMap().toSeq - .filter(p => !localParamNames.contains(p.param.name)) + .filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) completion @@ -267,7 +265,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer } override def transformSchema(schema: StructType): StructType = { - openAICompletion match { + val transformedSchema = openAICompletion match { case chatCompletion: OpenAIChatCompletion => chatCompletion .transformSchema(schema.add(getMessagesCol, StructType(Seq()))) @@ -277,6 +275,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer .transformSchema(schema) .add(getPostProcessing, getParser.outputSchema) } + + // Move error column to back + val errorFieldOpt: Option[StructField] = transformedSchema.fields.find(_.name == getErrorCol) + val fieldsWithoutError: Array[StructField] = transformedSchema.fields.filterNot(_.name == getErrorCol) + val reorderedFields = Array.concat(fieldsWithoutError, errorFieldOpt.toArray) + StructType(reorderedFields) } }