Skip to content

Commit

Permalink
MLflow model adapter (#517)
Browse files Browse the repository at this point in the history
* Updates adapters

* Completes mlflow model adapter

* Restores some unrelated changes with this PR
  • Loading branch information
fedefernandez authored Nov 6, 2023
1 parent fad4233 commit 3a8c11a
Show file tree
Hide file tree
Showing 5 changed files with 365 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonPrimitive
Expand All @@ -29,49 +28,39 @@ internal constructor(private val urlMap: Map<OpenAIPathType, Map<String, String>
ModelUriAdapterBuilder().apply(block).build()

override fun install(plugin: ModelUriAdapter, scope: HttpClient) {
installModelAuthAdapter(plugin, scope)
installModelUriAdapter(plugin, scope)
}

private fun readModelFromRequest(originalRequest: OutgoingContent.ByteArrayContent?): String? {
val requestBody = originalRequest?.bytes()?.toString(Charsets.UTF_8)
val json = requestBody?.let { Json.decodeFromString<JsonObject>(it) }
return json?.get("model")?.jsonPrimitive?.content
private fun readModelFromRequest(originalRequest: ByteArray): String? {
val requestBody = originalRequest.toString(Charsets.UTF_8)
val json = Json.decodeFromString<JsonObject>(requestBody)
return json["model"]?.jsonPrimitive?.content
}

private fun installModelAuthAdapter(plugin: ModelUriAdapter, scope: HttpClient) {
val adaptAuthRequestPhase = PipelinePhase("ModelAuthAdaptRequest")
scope.sendPipeline.insertPhaseAfter(HttpSendPipeline.State, adaptAuthRequestPhase)
scope.sendPipeline.intercept(adaptAuthRequestPhase) { content ->
private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept
if (plugin.isDefined(originalPath)) {
val originalRequest = content as? OutgoingContent.ByteArrayContent
if (originalRequest == null) {
plugin.logger.warn {
"""
|Can't adapt the model auth.
|The body type is: ${content::class}, with Content-Type: ${context.contentType()}.
|
|If you expect serialized body, please check that you have installed the corresponding
|plugin(like `ContentNegotiation`) and set `Content-Type` header."""
.trimMargin()
}
return@intercept
if (!plugin.isDefined(originalPath)) return@intercept
val model =
when (content) {
is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes())
is ByteArray -> readModelFromRequest(content)
else -> return@intercept
}
val model = readModelFromRequest(originalRequest)
val newURL = model?.let { plugin.findPath(originalPath, it) }
if (newURL == null) {
plugin.logger.info {
"Model auth didn't found a new url for path $originalPath and model $model"
}
} else {
val baseBuilder = URLBuilder(newURL).build()
context.url.set(
scheme = baseBuilder.protocol.name,
host = baseBuilder.host,
port = baseBuilder.port,
path = baseBuilder.encodedPath
)
val newURL = model?.let { plugin.findPath(originalPath, it) }
if (newURL == null)
plugin.logger.info { "New url for path $originalPath and model $model not found" }
else {
plugin.logger.info {
"Intercepting request for path $originalPath and model $model to $newURL"
}
val baseBuilder = URLBuilder(newURL).build()
context.url.set(
scheme = baseBuilder.protocol.name,
host = baseBuilder.host,
port = baseBuilder.port,
path = baseBuilder.encodedPath
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.xebia.functional.xef.server.http.client.mlflow

import com.xebia.functional.xef.server.http.client.OpenAIPathType
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.util.reflect.*
import io.ktor.utils.io.core.*
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

class MLflowModelAdapter
internal constructor(private val mappedRequests: Map<String, OpenAIPathType>) {

val logger = KotlinLogging.logger {}

fun mappedType(path: String): OpenAIPathType? = mappedRequests[path]

companion object : HttpClientPlugin<MLflowModelAdapterBuilder, MLflowModelAdapter> {

@OptIn(ExperimentalSerializationApi::class)
private val json = Json {
explicitNulls = false
ignoreUnknownKeys = true
}

override val key: AttributeKey<MLflowModelAdapter> = AttributeKey("MLflowAdapter")

override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter =
MLflowModelAdapterBuilder().apply(block).build()

override fun install(plugin: MLflowModelAdapter, scope: HttpClient) {
installMLflowRequestAdapter(plugin, scope)
installMLflowResponseAdapter(plugin, scope)
}

private suspend inline fun <reified R1, reified R2> PipelineContext<Any, HttpRequestBuilder>
.update(originalRequest: ByteArray, toMLflow: R1.() -> R2) {
val stringRequestBody = originalRequest.toString(Charsets.UTF_8)
val requestData = json.decodeFromString<R1>(stringRequestBody)
val newRequest =
TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json)
proceedWith(newRequest)
}

private suspend inline fun <reified R1, reified R2> PipelineContext<
HttpResponseContainer, HttpClientCall
>
.update(typeInfo: TypeInfo, contentResponse: ByteReadPacket, toXef: R1.() -> R2) {
val stringResponseBody = contentResponse.readText()
val responseData = json.decodeFromString<R1>(stringResponseBody)
val newResponse =
ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8))
val response = HttpResponseContainer(typeInfo, newResponse)
proceedWith(response)
}

private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept
val originalRequest =
when (content) {
is OutgoingContent.ByteArrayContent -> content.bytes()
is ByteArray -> content
else -> return@intercept
}
when (originalPath) {
OpenAIPathType.CHAT -> {
plugin.logger.info { "Intercepting chat request for path $originalPath" }
update(originalRequest, XefChatRequest::toMLflow)
}
OpenAIPathType.EMBEDDINGS -> {
plugin.logger.info { "Intercepting embeddings request for path $originalPath" }
update(originalRequest, XefEmbeddingsRequest::toMLflow)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
}

private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept
val contentResponse = content.response as? ByteReadPacket ?: return@intercept
when (originalPath) {
OpenAIPathType.CHAT -> {
plugin.logger.info { "Intercepting chat response for path $originalPath" }
update(content.expectedType, contentResponse, MLflowChatResponse::toXef)
}
OpenAIPathType.EMBEDDINGS -> {
plugin.logger.info { "Intercepting embeddings response for path $originalPath" }
update(content.expectedType, contentResponse, MLflowEmbeddingsResponse::toXef)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.xebia.functional.xef.server.http.client.mlflow

import com.xebia.functional.xef.server.http.client.OpenAIPathType

class MLflowModelAdapterBuilder {

private var pathTypeMap: Map<String, OpenAIPathType> = LinkedHashMap()

fun addToPath(path: String, pathType: OpenAIPathType) {
this.pathTypeMap += mapOf(path to pathType)
}

internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap)
}
Loading

0 comments on commit 3a8c11a

Please sign in to comment.