Skip to content

Commit

Permalink
Rename the error column and move it to the back of the DataFrame
Browse files: browse the repository at this point in the history
  • Loading branch information
sss04 committed Jan 6, 2025
1 parent 68ad5ee commit 2cf4f64
Showing 1 changed file with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DataType, StructType}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -111,7 +111,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
postProcessing -> "",
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"),
errorCol -> "gen_error",
messagesCol -> (this.uid + "_messages"),
dropPrompt -> true,
systemPrompt -> defaultSystemPrompt,
Expand Down Expand Up @@ -152,7 +152,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
val df = dataset.toDF
val completion = openAICompletion
val promptCol = Functions.template(getPromptTemplate)

val newErrorCol = df.withDerivativeCol(getErrorCol)
setErrorCol(newErrorCol)
completion.setErrorCol(newErrorCol)
completion match {
case chatCompletion: OpenAIChatCompletion =>
if (isSet(responseFormat)) {
Expand All @@ -173,12 +175,13 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

// move error col to back of df
val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
if (getDropPrompt) {
results.drop(messageColName)
resultsFinal.drop(messageColName)
} else {
results
resultsFinal
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for completion models")
Expand All @@ -194,10 +197,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.getField("text")))
.drop(completionNamed.getOutputCol)

// move error col to back of df
val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
if (getDropPrompt) {
results.drop(promptColName)
resultsFinal.drop(promptColName)
} else {
results
resultsFinal
}
}
}, dataset.columns.length)
Expand Down

0 comments on commit 2cf4f64

Please sign in to comment.