Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: OpenAIPrompt bug fixes #2334

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DataType, StructField, StructType}

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -172,32 +172,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
if (getDropPrompt) {
results.drop(messageColName)
resultsFinal.drop(messageColName)
} else {
results
resultsFinal
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for completion models")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)
// run completion
val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completionNamed.getOutputCol)

val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this necessary, do error colum,ns show up before other columns?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now the Error column appears before the output column - we wanted it at the end so I added this

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this change reflected in transform schema?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After trying modifying SimpleHTTPTransformer, I think the withColumn(getOutputCol on line 192 is moving the output column to the back

Copy link
Contributor Author

@sss04 sss04 Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did some surface level research and it seems if we're rearranging columns, transformSchema doesn't need a change? if that's true, we're good with the code as is

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it unfortunately will as it is ordered under the hood, you can ovverride this func too to do the switch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed transformSchema to move the Error column to the back

if (getDropPrompt) {
results.drop(promptColName)
resultsFinal.drop(promptColName)
} else {
results
resultsFinal
}
}
}, dataset.columns.length)
Expand Down Expand Up @@ -238,7 +236,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
.filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

completion
Expand Down Expand Up @@ -267,7 +265,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
val transformedSchema = openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
Expand All @@ -277,6 +275,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}

// Move error column to back
val errorFieldOpt: Option[StructField] = transformedSchema.fields.find(_.name == getErrorCol)
val fieldsWithoutError: Array[StructField] = transformedSchema.fields.filterNot(_.name == getErrorCol)
val reorderedFields = Array.concat(fieldsWithoutError, errorFieldOpt.toArray)
StructType(reorderedFields)
}
}

Expand Down
Loading