Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding code for sending parameter "response_format" as request payload #2317

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
maxTokens,
temperature,
topP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
Expand All @@ -18,8 +19,84 @@ import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

/**
 * Response formats that can be requested from the chat completion service.
 *
 * Each entry carries the `name` placed in the request payload and a `prompt`
 * sentence describing the expected output.
 */
object OpenAIResponseFormat extends Enumeration {

  /** An enumeration value identified by `name`, paired with its instruction `prompt`. */
  case class ResponseFormat(name: String, prompt: String) extends super.Val(name)

  // Declaration order is kept (TEXT first, then JSON) because Enumeration ids follow it.
  val TEXT: ResponseFormat = ResponseFormat(
    name = "text",
    prompt = "Output must be in text format")

  val JSON: ResponseFormat = ResponseFormat(
    name = "json_object",
    prompt = "Output must be in JSON format")
}


/**
 * Extends the shared OpenAI text parameters with an optional `responseFormat`
 * service parameter, which is sent in the request payload as `response_format`.
 */
trait HasOpenAITextParamsExtended extends HasOpenAITextParams {
  val responseFormat: ServiceParam[Map[String, String]] = new ServiceParam[Map[String, String]](
    this,
    "responseFormat",
    "Response format for the completion. Can be 'json_object' or 'text'.",
    isRequired = false) {
    override val payloadName: String = "response_format"
  }

  // Names of every format declared in OpenAIResponseFormat, in alphabetical order so that
  // the list printed in error messages is stable.
  private def supportedResponseFormats: Seq[String] =
    OpenAIResponseFormat.values.toSeq.collect {
      case format: OpenAIResponseFormat.ResponseFormat => format.name
    }.sorted

  def getResponseFormat: Map[String, String] = getScalarParam(responseFormat)

  /**
   * Sets the response format from a raw payload map.
   *
   * @param value map whose "type" entry must name a format declared in [[OpenAIResponseFormat]]
   * @throws IllegalArgumentException if "type" is missing or is not a supported format name
   */
  def setResponseFormat(value: Map[String, String]): this.type = {
    // `get` rather than `apply`: a map without a "type" entry is rejected with the same
    // IllegalArgumentException as an unsupported type instead of a NoSuchElementException.
    if (!value.get("type").exists(supportedResponseFormats.contains)) {
      throw new IllegalArgumentException("Response format must be 'text' or 'json_object'")
    }
    setScalarParam(responseFormat, value)
  }

  /**
   * Sets the response format from its name.
   *
   * The name is lower-cased and "json" is accepted as shorthand for "json_object".
   * An empty string leaves the parameter untouched.
   *
   * @param value format name, matched case-insensitively
   * @throws IllegalArgumentException if the normalized name is not a supported format
   */
  def setResponseFormat(value: String): this.type = {
    if (value.isEmpty) {
      this
    } else {
      val normalizedValue = value.toLowerCase match {
        case "json" => "json_object"
        case other => other
      }
      // Validate the normalized value against the names declared in OpenAIResponseFormat
      if (!supportedResponseFormats.contains(normalizedValue)) {
        throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
          "Currently supported formats are " + supportedResponseFormats.mkString(", "))
      }

      setScalarParam(responseFormat, Map("type" -> normalizedValue))
    }
  }

  /**
   * Sets the response format from an [[OpenAIResponseFormat]] value.
   *
   * No check on the kind of completion is made here; this simply forwards the
   * value's name to the String overload.
   */
  def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
    this.setResponseFormat(value.name)
  }

  def getResponseFormatCol: String = getVectorParam(responseFormat)

  def setResponseFormatCol(value: String): this.type = setVectorParam(responseFormat, value)


  // Recreating the sharedTextParams sequence to include additional parameter responseFormat.
  // This val is declared after responseFormat on purpose: initializers run in declaration
  // order, so responseFormat is already assigned when the sequence is built.
  override private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
    maxTokens,
    temperature,
    topP,
    user,
    n,
    echo,
    stop,
    cacheLevel,
    presencePenalty,
    frequencyPenalty,
    bestOf,
    logProbs,
    responseFormat // Additional parameter
  )
}

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
with HasOpenAITextParamsExtended with HasMessagesInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down Expand Up @@ -55,12 +132,23 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
override def responseDataType: DataType = ChatCompletionResponse.schema

private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
).toMap.filter(_._2.isDefined).mapValues(_.get)
// Convert each message row to a map of non-null values
val mappedMessages: Seq[Map[String, String]] = messages.map { messageRow =>
Seq("role", "content", "name").map { fieldName =>
fieldName -> Option(messageRow.getAs[String](fieldName))
}.toMap.filter(_._2.isDefined).mapValues(_.get)
}

// Check if the response format is JSON and add a system message if needed
val updatedMessages = if (optionalParams.get("response_format")
.exists(_.asInstanceOf[Map[String, String]]("type") == "json_object")) {
mappedMessages :+ Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt)
} else {
mappedMessages
}
val fullPayload = optionalParams.updated("messages", mappedMessages)

// Update the optional parameters with the messages
val fullPayload = optionalParams.updated("messages", updatedMessages)
new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams with HasMessagesInput
with HasOpenAITextParamsExtended with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
Expand Down Expand Up @@ -60,7 +60,37 @@ class OpenAIPrompt(override val uid: String) extends Transformer

def getPostProcessingOptions: Map[String, String] = $(postProcessingOptions)

def setPostProcessingOptions(value: Map[String, String]): this.type = set(postProcessingOptions, value)
/**
 * Sets the post-processing options and keeps `postProcessing` consistent with them.
 *
 * The kind of post-processing is inferred from the keys present, checked in this order:
 * "delimiter" selects "csv", "jsonSchema" selects "json", "regex" selects "regex".
 * If `postProcessing` is already set it must agree with the inferred kind (compared
 * case-insensitively); otherwise it is set to the inferred kind.
 *
 * @param value option map containing at least one of "delimiter", "jsonSchema" or "regex";
 *              "regex" additionally requires an integer "regexGroup"
 * @throws IllegalArgumentException if none of the known keys is present, if "regexGroup" is
 *                                  missing or not an integer, or if `postProcessing` was
 *                                  already set to a different kind
 */
def setPostProcessingOptions(value: Map[String, String]): this.type = {
  // The regex parser is built with opts("regexGroup").toInt, so reject a missing or
  // non-numeric group here rather than failing later when the parser is created.
  def validateRegexOptions(options: Map[String, String]): Unit = {
    require(options.contains("regexGroup"), "regexGroup must be specified with regex")
    require(
      scala.util.Try(options("regexGroup").toInt).isSuccess,
      "regexGroup must be an integer")
  }

  // Set postProcessing when absent; otherwise check that it agrees with the expected kind.
  // The comparison lower-cases the current value because the parser selection does the same,
  // so a value such as "CSV" is a valid way of spelling "csv".
  def setOrValidatePostProcessing(expected: String): Unit = {
    if (isSet(postProcessing)) {
      require(getPostProcessing.toLowerCase == expected, s"postProcessing must be '$expected'")
    } else {
      set(postProcessing, expected)
    }
  }

  // Match on the keys in the provided value map to set the appropriate post-processing option
  value match {
    case v if v.contains("delimiter") =>
      setOrValidatePostProcessing("csv")
    case v if v.contains("jsonSchema") =>
      setOrValidatePostProcessing("json")
    case v if v.contains("regex") =>
      validateRegexOptions(v)
      setOrValidatePostProcessing("regex")
    case _ =>
      throw new IllegalArgumentException("Invalid post processing options")
  }

  // Set the postProcessingOptions parameter with the provided value map
  set(postProcessingOptions, value)
}

/**
 * Java-friendly overload taking a `java.util.HashMap`.
 *
 * Converts the map to an immutable Scala map and delegates to the Scala `Map` overload,
 * so both entry points share whatever handling that overload applies.
 */
def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type =
  setPostProcessingOptions(v.asScala.toMap)
Expand Down Expand Up @@ -99,16 +129,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")
"systemPrompt", "responseFormat")

private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
df.map({ row =>
val originalOutput = Option(row.getAs[Row](outputCol))
.map({ row => openAIResultFromRow(row).choices.head })
val isFiltered = originalOutput
.map(output => Option(output.message.content).isEmpty)
.getOrElse(false)
val isFiltered = originalOutput.exists(output => Option(output.message.content).isEmpty)

if (isFiltered) {
val updatedRowSeq = row.toSeq.updated(
Expand All @@ -127,7 +155,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer

logTransform[DataFrame]({
val df = dataset.toDF

val completion = openAICompletion
val promptCol = Functions.template(getPromptTemplate)
val createMessagesUDF = udf((userMessage: String) => {
Expand All @@ -138,18 +165,19 @@ class OpenAIPrompt(override val uid: String) extends Transformer
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
if (isSet(responseFormat)) {
chatCompletion.setResponseFormat(getResponseFormat)
}
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

val transformed = addRAIErrors(
completionNamed.transform(dfTemplated), chatCompletion.getErrorCol, chatCompletion.getOutputCol)

val results = transformed
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)
.getField("message").getField("content"))).drop(completionNamed.getOutputCol)

if (getDropPrompt) {
results.drop(messageColName)
Expand All @@ -158,6 +186,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for OpenAICompletion")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)
Expand Down Expand Up @@ -215,8 +246,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer

getPostProcessing.toLowerCase match {
case "csv" => new DelimiterParser(opts.getOrElse("delimiter", ","))
case "json" => new JsonParser(opts.get("jsonSchema").get, Map.empty)
case "regex" => new RegexParser(opts.get("regex").get, opts.get("regexGroup").get.toInt)
case "json" => new JsonParser(opts("jsonSchema"), Map.empty)
case "regex" => new RegexParser(opts("regex"), opts("regexGroup").toInt)
case "" => new PassThroughParser()
case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, Transformer
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
import org.scalatest.matchers.must.Matchers.{an, be}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] with OpenAIAPIKey with Flaky {

Expand Down Expand Up @@ -174,6 +176,45 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
testCompletion(customEndpointCompletion, goodDf)
}

test("setResponseFormat should set the response format correctly") {
  // Each pair is (value handed to the setter, "type" expected in the stored map).
  // The cases run in the listed order against the same completion instance.
  val inputToExpectedType = Seq(
    "text" -> "text",
    "tExT" -> "text",
    "json" -> "json_object",
    "JSON" -> "json_object",
    "json_object" -> "json_object",
    "Json_ObjeCt" -> "json_object"
  )

  inputToExpectedType.foreach { case (input, expectedType) =>
    completion.setResponseFormat(input)
    completion.getResponseFormat shouldEqual Map("type" -> expectedType)
  }
}

test("setResponseFormat should throw an exception for invalid response format") {
  // A name that is neither "text", "json" nor "json_object" must be rejected by the setter
  val unsupportedFormat = "invalid_format"

  an[IllegalArgumentException] should be thrownBy {
    completion.setResponseFormat(unsupportedFormat)
  }
}

test("setResponseFormat with ResponseFormat should set the response format correctly") {
  // Each enumeration value is paired with the "type" it is expected to store
  val formatToExpectedType = Seq(
    OpenAIResponseFormat.TEXT -> "text",
    OpenAIResponseFormat.JSON -> "json_object"
  )

  formatToExpectedType.foreach { case (format, expectedType) =>
    completion.setResponseFormat(format)
    completion.getResponseFormat shouldEqual Map("type" -> expectedType)
  }
}

test("setResponseFormatCol should set the response format column correctly") {
  // Round-trip: the column name given to the setter is what the getter returns.
  // NOTE(review): `completion` is defined outside this view; if it is a fixture shared
  // across tests, the column set here stays configured for later tests - confirm.
  val responseFormatColumn = "response_format_col"

  completion.setResponseFormatCol(responseFormatColumn)
  completion.getResponseFormatCol shouldEqual responseFormatColumn
}

def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = {
val fromRow = ChatCompletionResponse.makeFromRowConverter
completion.transform(df).collect().foreach(r =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ trait OpenAIAPIKey {
// Deployment and model identifiers made available to suites that mix in this trait.
// NOTE(review): presumably each name must match a deployment that exists on the OpenAI
// resource the tests run against - confirm before adding or renaming entries.
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
lazy val deploymentNameGpt4: String = "gpt-4"
lazy val deploymentNameDavinci3: String = "text-davinci-003"
lazy val deploymentNameGpt4o: String = "gpt-4o"
lazy val modelNameGpt4: String = "gpt-4"
}

Expand Down
Loading
Loading