From 3a8c11a370a8579388be3baa5410aa4e9d79e59f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fede=20Fern=C3=A1ndez?= <720923+fedefernandez@users.noreply.github.com> Date: Mon, 6 Nov 2023 17:27:19 +0100 Subject: [PATCH] MLflow model adapter (#517) * Updates adapters * Completes mlflow model adapter * Restores some unrelated changes with this PR --- .../xef/server/http/client/ModelUriAdapter.kt | 63 +++--- .../http/client/mlflow/MLflowModelAdapter.kt | 113 ++++++++++ .../mlflow/MLflowModelAdapterBuilder.kt | 14 ++ .../server/http/client/mlflow/OpenAIModels.kt | 200 ++++++++++++++++++ .../xef/server/http/routes/AIRoutes.kt | 19 +- 5 files changed, 365 insertions(+), 44 deletions(-) create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt index 628560be9..9d256d74f 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt @@ -7,7 +7,6 @@ import io.ktor.client.request.* import io.ktor.http.* import io.ktor.http.content.* import io.ktor.util.* -import io.ktor.util.pipeline.* import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonPrimitive @@ -29,49 +28,39 @@ internal constructor(private val urlMap: Map ModelUriAdapterBuilder().apply(block).build() override fun install(plugin: ModelUriAdapter, scope: HttpClient) { - installModelAuthAdapter(plugin, scope) + installModelUriAdapter(plugin, scope) } - private fun readModelFromRequest(originalRequest: OutgoingContent.ByteArrayContent?): String? { - val requestBody = originalRequest?.bytes()?.toString(Charsets.UTF_8) - val json = requestBody?.let { Json.decodeFromString(it) } - return json?.get("model")?.jsonPrimitive?.content + private fun readModelFromRequest(originalRequest: ByteArray): String? { + val requestBody = originalRequest.toString(Charsets.UTF_8) + val json = Json.decodeFromString(requestBody) + return json["model"]?.jsonPrimitive?.content } - private fun installModelAuthAdapter(plugin: ModelUriAdapter, scope: HttpClient) { - val adaptAuthRequestPhase = PipelinePhase("ModelAuthAdaptRequest") - scope.sendPipeline.insertPhaseAfter(HttpSendPipeline.State, adaptAuthRequestPhase) - scope.sendPipeline.intercept(adaptAuthRequestPhase) { content -> + private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) { + scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept - if (plugin.isDefined(originalPath)) { - val originalRequest = content as? OutgoingContent.ByteArrayContent - if (originalRequest == null) { - plugin.logger.warn { - """ - |Can't adapt the model auth. - |The body type is: ${content::class}, with Content-Type: ${context.contentType()}. - | - |If you expect serialized body, please check that you have installed the corresponding - |plugin(like `ContentNegotiation`) and set `Content-Type` header.""" - .trimMargin() - } - return@intercept + if (!plugin.isDefined(originalPath)) return@intercept + val model = + when (content) { + is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes()) + is ByteArray -> readModelFromRequest(content) + else -> return@intercept } - val model = readModelFromRequest(originalRequest) - val newURL = model?.let { plugin.findPath(originalPath, it) } - if (newURL == null) { - plugin.logger.info { - "Model auth didn't found a new url for path $originalPath and model $model" - } - } else { - val baseBuilder = URLBuilder(newURL).build() - context.url.set( - scheme = baseBuilder.protocol.name, - host = baseBuilder.host, - port = baseBuilder.port, - path = baseBuilder.encodedPath - ) + val newURL = model?.let { plugin.findPath(originalPath, it) } + if (newURL == null) + plugin.logger.info { "New url for path $originalPath and model $model not found" } + else { + plugin.logger.info { + "Intercepting request for path $originalPath and model $model to $newURL" } + val baseBuilder = URLBuilder(newURL).build() + context.url.set( + scheme = baseBuilder.protocol.name, + host = baseBuilder.host, + port = baseBuilder.port, + path = baseBuilder.encodedPath + ) } } } diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt new file mode 100644 index 000000000..17ad79ec0 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt @@ -0,0 +1,113 @@ +package com.xebia.functional.xef.server.http.client.mlflow + +import com.xebia.functional.xef.server.http.client.OpenAIPathType +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.* +import io.ktor.client.call.* +import io.ktor.client.plugins.* +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.http.content.* +import io.ktor.util.* +import io.ktor.util.pipeline.* +import io.ktor.util.reflect.* +import io.ktor.utils.io.core.* +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json + +class MLflowModelAdapter +internal constructor(private val mappedRequests: Map) { + + val logger = KotlinLogging.logger {} + + fun mappedType(path: String): OpenAIPathType? = mappedRequests[path] + + companion object : HttpClientPlugin { + + @OptIn(ExperimentalSerializationApi::class) + private val json = Json { + explicitNulls = false + ignoreUnknownKeys = true + } + + override val key: AttributeKey = AttributeKey("MLflowAdapter") + + override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter = + MLflowModelAdapterBuilder().apply(block).build() + + override fun install(plugin: MLflowModelAdapter, scope: HttpClient) { + installMLflowRequestAdapter(plugin, scope) + installMLflowResponseAdapter(plugin, scope) + } + + private suspend inline fun PipelineContext + .update(originalRequest: ByteArray, toMLflow: R1.() -> R2) { + val stringRequestBody = originalRequest.toString(Charsets.UTF_8) + val requestData = json.decodeFromString(stringRequestBody) + val newRequest = + TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json) + proceedWith(newRequest) + } + + private suspend inline fun PipelineContext< + HttpResponseContainer, HttpClientCall + > + .update(typeInfo: TypeInfo, contentResponse: ByteReadPacket, toXef: R1.() -> R2) { + val stringResponseBody = contentResponse.readText() + val responseData = json.decodeFromString(stringResponseBody) + val newResponse = + ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8)) + val response = HttpResponseContainer(typeInfo, newResponse) + proceedWith(response) + } + + private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { + scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> + val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept + val originalRequest = + when (content) { + is OutgoingContent.ByteArrayContent -> content.bytes() + is ByteArray -> content + else -> return@intercept + } + when (originalPath) { + OpenAIPathType.CHAT -> { + plugin.logger.info { "Intercepting chat request for path $originalPath" } + update(originalRequest, XefChatRequest::toMLflow) + } + OpenAIPathType.EMBEDDINGS -> { + plugin.logger.info { "Intercepting embeddings request for path $originalPath" } + update(originalRequest, XefEmbeddingsRequest::toMLflow) + } + else -> { + plugin.logger.warn { "$originalPath not supported" } + return@intercept + } + } + } + } + + private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { + scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content -> + val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept + val contentResponse = content.response as? ByteReadPacket ?: return@intercept + when (originalPath) { + OpenAIPathType.CHAT -> { + plugin.logger.info { "Intercepting chat response for path $originalPath" } + update(content.expectedType, contentResponse, MLflowChatResponse::toXef) + } + OpenAIPathType.EMBEDDINGS -> { + plugin.logger.info { "Intercepting embeddings response for path $originalPath" } + update(content.expectedType, contentResponse, MLflowEmbeddingsResponse::toXef) + } + else -> { + plugin.logger.warn { "$originalPath not supported" } + return@intercept + } + } + } + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt new file mode 100644 index 000000000..62a49dee3 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt @@ -0,0 +1,14 @@ +package com.xebia.functional.xef.server.http.client.mlflow + +import com.xebia.functional.xef.server.http.client.OpenAIPathType + +class MLflowModelAdapterBuilder { + + private var pathTypeMap: Map = LinkedHashMap() + + fun addToPath(path: String, pathType: OpenAIPathType) { + this.pathTypeMap += mapOf(path to pathType) + } + + internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap) +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt new file mode 100644 index 000000000..3902ba800 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt @@ -0,0 +1,200 @@ +package com.xebia.functional.xef.server.http.client.mlflow + +import io.ktor.util.date.* +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.builtins.ListSerializer +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonTransformingSerializer +import kotlinx.serialization.json.jsonArray +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +@Serializable +data class XefChatRequest( + val model: String, + @Serializable(with = XefChatMessageSerializable::class) val messages: List, + val temperature: Double? = null, + val n: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null, + val stream: Boolean? = false +) + +@Serializable data class XefChatMessage(val role: String, val content: String) + +@Serializable +data class XefEmbeddingsRequest( + val model: String, + @Serializable(with = StringArraySerializable::class) val input: List, + val encodingFormat: XefEncodingFormat? = null +) + +@Serializable +enum class XefEncodingFormat { + @SerialName("float") FLOAT, + @SerialName("base64") BASE64 +} + +@Serializable +data class XefChatResponse( + val id: String, + @SerialName("object") val objectModel: String, + val created: Long, + val model: String, + val choices: List, + val usage: XefResponseUsage +) + +@Serializable +data class XefResponseChoice( + val index: Int, + val message: XefResponseMessage, + @SerialName("finish_reason") val finishReason: String? +) + +@Serializable data class XefResponseMessage(val role: String, val content: String) + +@Serializable +data class XefResponseUsage( + @SerialName("prompt_tokens") val promptTokens: Int? = null, + @SerialName("completion_tokens") val completionTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int? = null +) + +@Serializable +data class XefEmbeddingResponse( + @SerialName("object") val objectModel: String, + val data: List, + val model: String, + val usage: XefResponseUsage +) + +@Serializable +data class XefEmbedding( + @SerialName("object") val objectModel: String, + val index: Int, + val embedding: List +) + +@Serializable +data class MLflowChatRequest( + @Serializable(with = MLflowChatMessageSerializable::class) val messages: List, + val temperature: Double? = null, + @SerialName("candidate_count") val candidateCount: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null +) + +@Serializable +data class MLflowChatResponse( + val candidates: List, + val metadata: MLflowResponseMetadata +) + +@Serializable +data class MLflowChatCandidate( + val message: MLflowChatMessage, + val metadata: MLflowCandidateMetadata +) + +@Serializable data class MLflowChatMessage(val role: String, val content: String) + +@Serializable +data class MLflowCandidateMetadata(@SerialName("finish_reason") val finishReason: String?) + +@Serializable +data class MLflowEmbeddingsRequest( + @Serializable(with = StringArraySerializable::class) val text: List +) + +@Serializable +data class MLflowEmbeddingsResponse( + val embeddings: List>, + val metadata: MLflowResponseMetadata +) + +@Serializable +data class MLflowResponseMetadata( + val model: String, + @SerialName("route_type") val routeType: MLflowRouteType, + @SerialName("input_tokens") val inputTokens: Int? = null, + @SerialName("output_tokens") val outputTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int? = null +) + +@Serializable +enum class MLflowRouteType { + @SerialName("llm/v1/completions") COMPLETIONS, + @SerialName("llm/v1/chat") CHAT, + @SerialName("llm/v1/embeddings") EMBEDDINGS +} + +private object XefChatMessageSerializable : + JsonTransformingSerializer>(ListSerializer(XefChatMessage.serializer())) { + override fun transformDeserialize(element: JsonElement): JsonElement = + if (element !is JsonArray) JsonArray(listOf(element)) else element +} + +private object MLflowChatMessageSerializable : + JsonTransformingSerializer>( + ListSerializer(MLflowChatMessage.serializer()) + ) { + override fun transformSerialize(element: JsonElement): JsonElement = + if (element is JsonArray) { + val jsonArray = element.jsonArray + if (jsonArray.size == 1) jsonArray[0] else jsonArray + } else element +} + +private object StringArraySerializable : + JsonTransformingSerializer>(ListSerializer(String.serializer())) { + override fun transformDeserialize(element: JsonElement): JsonElement = + if (element !is JsonArray) JsonArray(listOf(element)) else element + + override fun transformSerialize(element: JsonElement): JsonElement = + if (element is JsonArray) { + val jsonArray = element.jsonArray + if (jsonArray.size == 1) jsonArray[0] else jsonArray + } else element +} + +fun XefChatRequest.toMLflow(): MLflowChatRequest = + MLflowChatRequest( + messages.map { MLflowChatMessage(it.role, it.content) }, + temperature, + n, + stop, + maxTokens + ) + +fun XefEmbeddingsRequest.toMLflow(): MLflowEmbeddingsRequest = MLflowEmbeddingsRequest(input) + +fun MLflowChatResponse.toXef(): XefChatResponse = + XefChatResponse( + UUID.generateUUID().toString(), + "chat.completion", + getTimeMillis(), + metadata.model, + candidates.mapIndexed { index, candidate -> + XefResponseChoice( + index, + XefResponseMessage(candidate.message.role, candidate.message.content), + candidate.metadata.finishReason + ) + }, + metadata.toXef() + ) + +fun MLflowEmbeddingsResponse.toXef(): XefEmbeddingResponse = + XefEmbeddingResponse( + "list", + embeddings.mapIndexed { index, list -> XefEmbedding("embedding", index, list) }, + metadata.model, + metadata.toXef().copy(completionTokens = null) + ) + +private fun MLflowResponseMetadata.toXef(): XefResponseUsage = + XefResponseUsage(inputTokens, outputTokens, totalTokens) diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt index 407c12430..ce67cc16e 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt @@ -62,6 +62,10 @@ fun Routing.aiRoutes(client: HttpClient) { } } +private val conflictingRequestHeaders = + listOf("Host", "Content-Type", "Content-Length", "Accept", "Accept-Encoding") +private val conflictingResponseHeaders = listOf("Content-Length") + private suspend fun HttpClient.makeRequest( call: ApplicationCall, url: String, @@ -75,8 +79,9 @@ private suspend fun HttpClient.makeRequest( method = HttpMethod.Post setBody(body) } - call.response.headers.copyFrom(response.headers, "Content-Length") - call.respond(response.status, response.readBytes()) + call.response.headers.copyFrom(response.headers) + // `response.bodyAsText()` is needed for triggering responsePipeline intercept + call.respond(response.status, response.bodyAsText()) } private suspend fun HttpClient.makeStreaming( @@ -91,24 +96,24 @@ private suspend fun HttpClient.makeStreaming( setBody(body) } .execute { httpResponse -> - call.response.headers.copyFrom(httpResponse.headers, "Content-Length") + call.response.headers.copyFrom(httpResponse.headers) call.respondOutputStream { httpResponse.bodyAsChannel().copyTo(this@respondOutputStream) } } } -private fun ResponseHeaders.copyFrom(headers: Headers, vararg filterOut: String) = +private fun ResponseHeaders.copyFrom(headers: Headers) = headers .entries() .filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception - .filterNot { (key, _) -> filterOut.any { it.equals(key, true) } } + .filterNot { (key, _) -> conflictingResponseHeaders.any { it.equals(key, true) } } .forEach { (key, values) -> values.forEach { value -> this.appendIfAbsent(key, value) } } internal fun HeadersBuilder.copyFrom(headers: Headers) = headers - .filter { key, value -> !key.equals("HOST", ignoreCase = true) } - .forEach { key, values -> appendAll(key, values) } + .filter { key, _ -> !conflictingRequestHeaders.any { it.equals(key, true) } } + .forEach { key, values -> appendMissing(key, values) } private fun ApplicationCall.getProvider(): Provider = request.headers["xef-provider"]?.toProvider() ?: Provider.OPENAI