Skip to content

Commit

Permalink
feat(assistants): add structured response (#391)
Browse files Browse the repository at this point in the history
Co-authored-by: ahmedharis994 <[email protected]>
Co-authored-by: Ahmed Haris Javaid Mirza <[email protected]>
  • Loading branch information
3 people authored Oct 26, 2024
1 parent cf06188 commit c0a3945
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ import com.aallam.openai.api.assistant.AssistantResponseFormat
import com.aallam.openai.api.assistant.AssistantTool
import com.aallam.openai.api.assistant.assistantRequest
import com.aallam.openai.api.chat.ToolCall
import com.aallam.openai.api.core.RequestOptions
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.api.run.RequiredAction
import com.aallam.openai.api.run.Run
import com.aallam.openai.client.internal.JsonLenient
import kotlin.test.*
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNull
import kotlin.test.assertTrue

class TestAssistants : TestOpenAI() {

Expand Down Expand Up @@ -144,4 +151,64 @@ class TestAssistants : TestOpenAI() {
val action = decoded.requiredAction as RequiredAction.SubmitToolOutputs
assertIs<ToolCall.Function>(action.toolOutputs.toolCalls.first())
}

@Test
fun jsonSchemaAssistant() = test {
val jsonSchema = AssistantResponseFormat.JSON_SCHEMA(
name = "TestSchema",
description = "A test schema",
schema = buildJsonObject {
put("type", "object")
put("properties", buildJsonObject {
put("name", buildJsonObject {
put("type", "string")
})
})
put("required", JsonArray(listOf(JsonPrimitive("name"))))
put("additionalProperties", false)
},
strict = true
)

val request = assistantRequest {
name = "Schema Assistant"
model = ModelId("gpt-4o-mini")
responseFormat = jsonSchema
}

val assistant = openAI.assistant(
request = request,
)
assertEquals(request.name, assistant.name)
assertEquals(request.model, assistant.model)
assertEquals(request.responseFormat, assistant.responseFormat)

val getAssistant = openAI.assistant(
assistant.id,
)
assertEquals(getAssistant, assistant)

val assistants = openAI.assistants()
assertTrue { assistants.isNotEmpty() }

val updated = assistantRequest {
name = "Updated Schema Assistant"
responseFormat = AssistantResponseFormat.AUTO
}
val updatedAssistant = openAI.assistant(
assistant.id,
updated,
)
assertEquals(updated.name, updatedAssistant.name)
assertEquals(updated.responseFormat, updatedAssistant.responseFormat)

openAI.delete(
updatedAssistant.id,
)

val fileGetAfterDelete = openAI.assistant(
updatedAssistant.id,
)
assertNull(fileGetAfterDelete)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,23 @@ public data class AssistantRequest(
* Specifies the format that the model must output. Compatible with GPT-4o, GPT-4 Turbo, and all GPT-3.5 Turbo
* models since gpt-3.5-turbo-1106.
*
* Setting to [AssistantResponseFormat.JsonObject] enables JSON mode, which guarantees the message the model
* Setting to [AssistantResponseFormat.JSON_SCHEMA] enables Structured Outputs which ensures the model will match your supplied JSON schema.
*
* Structured Outputs ([AssistantResponseFormat.JSON_SCHEMA]) are available in our latest large language models, starting with GPT-4o:
* 1. gpt-4o-mini-2024-07-18 and later
* 2. gpt-4o-2024-08-06 and later
*
* Older models like gpt-4-turbo and earlier may use JSON mode ([AssistantResponseFormat.JSON_OBJECT]) instead.
*
* Setting to [AssistantResponseFormat.JSON_OBJECT] enables JSON mode, which guarantees the message the model
* generates is valid JSON.
*
* important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user
* message. Without this, the model may generate an unending stream of whitespace until the generation reaches the
* token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be
* partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or
* the conversation exceeded the max context length.
*
*/
@SerialName("response_format") val responseFormat: AssistantResponseFormat? = null,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,73 @@ import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.descriptors.element
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.booleanOrNull
import kotlinx.serialization.json.contentOrNull
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive

/**
* string: auto is the default value
* Represents the format of the response from the assistant.
*
* object: An object describing the expected output of the model. If json_object only function type tools are allowed to be passed to the Run.
* If text, the model can return text or any value needed.
* type: string Must be one of text or json_object.
* @property type The type of the response format.
* @property jsonSchema The JSON schema associated with the response format, if type is "json_schema" otherwise null.
*/
@BetaOpenAI
@Serializable(with = AssistantResponseFormat.ResponseFormatSerializer::class)
public data class AssistantResponseFormat(
val format: String? = null,
val objectType: AssistantResponseType? = null,
val type: String,
val jsonSchema: JsonSchema? = null
) {

/**
* Represents a JSON schema.
*
* @property name The name of the schema.
* @property description The description of the schema.
* @property schema The actual JSON schema.
* @property strict Indicates if the schema is strict.
*/
@Serializable
public data class AssistantResponseType(
val type: String
public data class JsonSchema(
val name: String,
val description: String? = null,
val schema: JsonObject,
val strict: Boolean? = null
)

public companion object {
public val AUTO: AssistantResponseFormat = AssistantResponseFormat(format = "auto")
public val TEXT: AssistantResponseFormat = AssistantResponseFormat(objectType = AssistantResponseType(type = "text"))
public val JSON_OBJECT: AssistantResponseFormat = AssistantResponseFormat(objectType = AssistantResponseType(type = "json_object"))
public val AUTO: AssistantResponseFormat = AssistantResponseFormat("auto")
public val TEXT: AssistantResponseFormat = AssistantResponseFormat("text")
public val JSON_OBJECT: AssistantResponseFormat = AssistantResponseFormat("json_object")

/**
* Creates an instance of `AssistantResponseFormat` with type `json_schema`.
*
* @param name The name of the schema.
* @param description The description of the schema.
* @param schema The actual JSON schema.
* @param strict Indicates if the schema is strict.
* @return An instance of `AssistantResponseFormat` with the specified JSON schema.
*/
public fun JSON_SCHEMA(
name: String,
description: String? = null,
schema: JsonObject,
strict: Boolean? = null
): AssistantResponseFormat = AssistantResponseFormat(
"json_schema",
JsonSchema(name, description, schema, strict)
)
}


public object ResponseFormatSerializer : KSerializer<AssistantResponseFormat> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("AssistantResponseFormat") {
element<String>("format", isOptional = true)
element<AssistantResponseType>("type", isOptional = true)
}

override fun serialize(encoder: Encoder, value: AssistantResponseFormat) {
val jsonEncoder = encoder as? kotlinx.serialization.json.JsonEncoder
?: throw SerializationException("This class can be saved only by Json")

if (value.format != null) {
jsonEncoder.encodeJsonElement(JsonPrimitive(value.format))
} else if (value.objectType != null) {
val jsonElement: JsonElement = JsonObject(mapOf("type" to JsonPrimitive(value.objectType.type)))
jsonEncoder.encodeJsonElement(jsonElement)
}
element<String>("type")
element<JsonSchema>("json_schema", isOptional = true) // Only for "json_schema" type
}

override fun deserialize(decoder: Decoder): AssistantResponseFormat {
Expand All @@ -63,14 +85,63 @@ public data class AssistantResponseFormat(
val jsonElement = jsonDecoder.decodeJsonElement()
return when {
jsonElement is JsonPrimitive && jsonElement.isString -> {
AssistantResponseFormat(format = jsonElement.content)
AssistantResponseFormat(type = jsonElement.content)
}
jsonElement is JsonObject && "type" in jsonElement -> {
val type = jsonElement["type"]!!.jsonPrimitive.content
AssistantResponseFormat(objectType = AssistantResponseType(type))
when (type) {
"json_schema" -> {
val schemaObject = jsonElement["json_schema"]?.jsonObject
val name = schemaObject?.get("name")?.jsonPrimitive?.content ?: ""
val description = schemaObject?.get("description")?.jsonPrimitive?.contentOrNull
val schema = schemaObject?.get("schema")?.jsonObject ?: JsonObject(emptyMap())
val strict = schemaObject?.get("strict")?.jsonPrimitive?.booleanOrNull
AssistantResponseFormat(
type = "json_schema",
jsonSchema = JsonSchema(name = name, description = description, schema = schema, strict = strict)
)
}
"json_object" -> AssistantResponseFormat(type = "json_object")
"auto" -> AssistantResponseFormat(type = "auto")
"text" -> AssistantResponseFormat(type = "text")
else -> throw SerializationException("Unknown response format type: $type")
}
}
else -> throw SerializationException("Unknown response format: $jsonElement")
}
}

override fun serialize(encoder: Encoder, value: AssistantResponseFormat) {
val jsonEncoder = encoder as? kotlinx.serialization.json.JsonEncoder
?: throw SerializationException("This class can be saved only by Json")

val jsonElement = when (value.type) {
"json_schema" -> {
JsonObject(
mapOf(
"type" to JsonPrimitive("json_schema"),
"json_schema" to JsonObject(
mapOf(
"name" to JsonPrimitive(value.jsonSchema?.name ?: ""),
"description" to JsonPrimitive(value.jsonSchema?.description ?: ""),
"schema" to (value.jsonSchema?.schema ?: JsonObject(emptyMap())),
"strict" to JsonPrimitive(value.jsonSchema?.strict ?: false)
)
)
)
)
}
"json_object" -> JsonObject(mapOf("type" to JsonPrimitive("json_object")))
"auto" -> JsonPrimitive("auto")
"text" -> JsonObject(mapOf("type" to JsonPrimitive("text")))
else -> throw SerializationException("Unsupported response format type: ${value.type}")
}
jsonEncoder.encodeJsonElement(jsonElement)
}

}
}

public fun JsonObject.Companion.buildJsonObject(block: JsonObjectBuilder.() -> Unit): JsonObject {
return kotlinx.serialization.json.buildJsonObject(block)
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public data class Run(
@SerialName("usage") public val usage: Usage? = null,

/**
* The Unix timestamp (in seconds) for when the run was completed.
* The sampling temperature used for this run. If not set, defaults to 1.
*/
@SerialName("temperature") val temperature: Double? = null,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.aallam.openai.sample.jvm

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.assistant.AssistantRequest
import com.aallam.openai.api.assistant.AssistantResponseFormat
import com.aallam.openai.api.assistant.AssistantTool
import com.aallam.openai.api.assistant.Function
import com.aallam.openai.api.chat.ToolCall
Expand All @@ -17,7 +18,10 @@ import com.aallam.openai.api.run.RunRequest
import com.aallam.openai.api.run.ToolOutput
import com.aallam.openai.client.OpenAI
import kotlinx.coroutines.delay
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.add
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonArray
import kotlinx.serialization.json.putJsonObject
Expand All @@ -29,6 +33,36 @@ suspend fun assistantsFunctions(openAI: OpenAI) {
request = AssistantRequest(
name = "Math Tutor",
instructions = "You are a weather bot. Use the provided functions to answer questions.",
responseFormat = AssistantResponseFormat.JSON_SCHEMA(
name = "math_response",
strict = true,
schema = buildJsonObject {
put("type", "object")
putJsonObject("properties") {
putJsonObject("steps") {
put("type", "array")
putJsonObject("items") {
put("type", "object")
putJsonObject("properties") {
putJsonObject("explanation") {
put("type", "string")
}
putJsonObject("output") {
put("type", "string")
}
}
put("required", JsonArray(listOf(JsonPrimitive("explanation"), JsonPrimitive("output"))))
put("additionalProperties", false)
}
}
putJsonObject("final_answer") {
put("type", "string")
}
}
put("additionalProperties", false)
put("required", JsonArray(listOf(JsonPrimitive("steps"), JsonPrimitive("final_answer"))))
},
),
tools = listOf(
AssistantTool.FunctionTool(
function = Function(
Expand Down Expand Up @@ -74,7 +108,7 @@ suspend fun assistantsFunctions(openAI: OpenAI) {
)
)
),
model = ModelId("gpt-4-1106-preview")
model = ModelId("gpt-4o-mini")
)
)

Expand Down

0 comments on commit c0a3945

Please sign in to comment.