Skip to content

Commit

Permalink
Completes mlflow model adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
fedefernandez committed Nov 3, 2023
1 parent 428316c commit d881c5a
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 216 deletions.
11 changes: 0 additions & 11 deletions server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import com.typesafe.config.ConfigFactory
import com.xebia.functional.xef.server.db.psql.Migrate
import com.xebia.functional.xef.server.db.psql.XefDatabaseConfig
import com.xebia.functional.xef.server.exceptions.exceptionsHandler
import com.xebia.functional.xef.server.http.client.ModelUriAdapter
import com.xebia.functional.xef.server.http.client.OpenAIPathType
import com.xebia.functional.xef.server.http.client.mlflow.MLflowModelAdapter
import com.xebia.functional.xef.server.http.routes.*
import com.xebia.functional.xef.server.services.PostgresVectorStoreService
import com.xebia.functional.xef.server.services.RepositoryService
Expand All @@ -19,9 +16,7 @@ import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.auth.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation
import io.ktor.client.plugins.logging.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
Expand Down Expand Up @@ -65,12 +60,6 @@ object Server {
install(Auth)
install(Logging) { level = LogLevel.INFO }
install(ClientContentNegotiation)
install(ModelUriAdapter) {
addToPath(OpenAIPathType.EMBEDDINGS, "ojete/calor" to "https://ca0f47a7-ade7-430a-8735-e1cea32ac960.mock.pstmn.io/http://127.0.0.1:5000/gateway/embeddings/invocations")
}
install(MLflowModelAdapter) {
addToPath("https://ca0f47a7-ade7-430a-8735-e1cea32ac960.mock.pstmn.io/http://127.0.0.1:5000/gateway/embeddings/invocations", OpenAIPathType.EMBEDDINGS)
}
}

server(factory = Netty, port = 8081, host = "0.0.0.0") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,58 @@ import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonPrimitive

class ModelUriAdapter internal constructor(private val urlMap: Map<OpenAIPathType, Map<String, String>>) {
class ModelUriAdapter
internal constructor(private val urlMap: Map<OpenAIPathType, Map<String, String>>) {

val logger = KotlinLogging.logger {}
val logger = KotlinLogging.logger {}

fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path)
fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path)

fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model)
fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model)

companion object : HttpClientPlugin<ModelUriAdapterBuilder, ModelUriAdapter> {
companion object : HttpClientPlugin<ModelUriAdapterBuilder, ModelUriAdapter> {

override val key: AttributeKey<ModelUriAdapter> = AttributeKey("ModelAuthAdapter")
override val key: AttributeKey<ModelUriAdapter> = AttributeKey("ModelAuthAdapter")

override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter =
ModelUriAdapterBuilder().apply(block).build()
override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter =
ModelUriAdapterBuilder().apply(block).build()

override fun install(plugin: ModelUriAdapter, scope: HttpClient) {
installModelUriAdapter(plugin, scope)
}
override fun install(plugin: ModelUriAdapter, scope: HttpClient) {
installModelUriAdapter(plugin, scope)
}

private fun readModelFromRequest(originalRequest: ByteArray): String? {
val requestBody = originalRequest.toString(Charsets.UTF_8)
val json = Json.decodeFromString<JsonObject>(requestBody)
return json["model"]?.jsonPrimitive?.content
}
private fun readModelFromRequest(originalRequest: ByteArray): String? {
val requestBody = originalRequest.toString(Charsets.UTF_8)
val json = Json.decodeFromString<JsonObject>(requestBody)
return json["model"]?.jsonPrimitive?.content
}

private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept
if (!plugin.isDefined(originalPath)) return@intercept
val model = when (content) {
is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes())
is ByteArray -> readModelFromRequest(content)
else -> return@intercept
}
val newURL = model?.let { plugin.findPath(originalPath, it) }
if (newURL == null) plugin.logger.info { "New url for path $originalPath and model $model not found" }
else {
plugin.logger.info { "Intercepting request for path $originalPath and model $model to $newURL" }
val baseBuilder = URLBuilder(newURL).build()
context.url.set(
scheme = baseBuilder.protocol.name,
host = baseBuilder.host,
port = baseBuilder.port,
path = baseBuilder.encodedPath
)
}
}
private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept
if (!plugin.isDefined(originalPath)) return@intercept
val model =
when (content) {
is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes())
is ByteArray -> readModelFromRequest(content)
else -> return@intercept
}
val newURL = model?.let { plugin.findPath(originalPath, it) }
if (newURL == null)
plugin.logger.info { "New url for path $originalPath and model $model not found" }
else {
plugin.logger.info {
"Intercepting request for path $originalPath and model $model to $newURL"
}
val baseBuilder = URLBuilder(newURL).build()
context.url.set(
scheme = baseBuilder.protocol.name,
host = baseBuilder.host,
port = baseBuilder.port,
path = baseBuilder.encodedPath
)
}

}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ class ModelUriAdapterBuilder {
}

internal fun build(): ModelUriAdapter = ModelUriAdapter(pathMap)
}
}
Original file line number Diff line number Diff line change
@@ -1,89 +1,113 @@
package com.xebia.functional.xef.server.http.client.mlflow

import com.xebia.functional.xef.server.http.client.ModelUriAdapter
import com.xebia.functional.xef.server.http.client.OpenAIPathType
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.util.reflect.*
import io.ktor.utils.io.core.*
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

class MLflowModelAdapter internal constructor(private val mappedRequests: Map<String, OpenAIPathType>) {
class MLflowModelAdapter
internal constructor(private val mappedRequests: Map<String, OpenAIPathType>) {

val logger = KotlinLogging.logger {}
val logger = KotlinLogging.logger {}

fun mappedType(path: String): OpenAIPathType? = mappedRequests[path]
fun mappedType(path: String): OpenAIPathType? = mappedRequests[path]

companion object : HttpClientPlugin<MLflowModelAdapterBuilder, MLflowModelAdapter> {
companion object : HttpClientPlugin<MLflowModelAdapterBuilder, MLflowModelAdapter> {

@OptIn(ExperimentalSerializationApi::class)
private val json = Json {
explicitNulls = false
ignoreUnknownKeys = true
}
@OptIn(ExperimentalSerializationApi::class)
private val json = Json {
explicitNulls = false
ignoreUnknownKeys = true
}

override val key: AttributeKey<MLflowModelAdapter> = AttributeKey("MLflowAdapter")
override val key: AttributeKey<MLflowModelAdapter> = AttributeKey("MLflowAdapter")

override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter =
MLflowModelAdapterBuilder().apply(block).build()
override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter =
MLflowModelAdapterBuilder().apply(block).build()

override fun install(plugin: MLflowModelAdapter, scope: HttpClient) {
installMLflowRequestAdapter(plugin, scope)
installMLflowResponseAdapter(plugin, scope)
}
override fun install(plugin: MLflowModelAdapter, scope: HttpClient) {
installMLflowRequestAdapter(plugin, scope)
installMLflowResponseAdapter(plugin, scope)
}

private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept
val originalRequest = when (content) {
is OutgoingContent.ByteArrayContent -> content.bytes()
is ByteArray -> content
else -> return@intercept
}
when (originalPath) {
OpenAIPathType.CHAT -> TODO()
OpenAIPathType.EMBEDDINGS -> {
plugin.logger.info { "Intercepting request for path $originalPath" }
val stringRequestBody = originalRequest.toString(Charsets.UTF_8)
val requestData = json.decodeFromString<XefEmbeddingsRequest>(stringRequestBody)
val newRequest = TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json)
proceedWith(newRequest)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
}
private suspend inline fun <reified R1, reified R2> PipelineContext<Any, HttpRequestBuilder>
.update(originalRequest: ByteArray, toMLflow: R1.() -> R2) {
val stringRequestBody = originalRequest.toString(Charsets.UTF_8)
val requestData = json.decodeFromString<R1>(stringRequestBody)
val newRequest =
TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json)
proceedWith(newRequest)
}

private suspend inline fun <reified R1, reified R2> PipelineContext<
HttpResponseContainer, HttpClientCall
>
.update(typeInfo: TypeInfo, contentResponse: ByteReadPacket, toXef: R1.() -> R2) {
val stringResponseBody = contentResponse.readText()
val responseData = json.decodeFromString<R1>(stringResponseBody)
val newResponse =
ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8))
val response = HttpResponseContainer(typeInfo, newResponse)
proceedWith(response)
}

private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept
val contentResponse = content.response as? ByteReadPacket ?: return@intercept
when (originalPath) {
OpenAIPathType.CHAT -> TODO()
OpenAIPathType.EMBEDDINGS -> {
val stringResponseBody = contentResponse.readText()
val responseData = json.decodeFromString<MLflowEmbeddingsResponse>(stringResponseBody)
val newResponse = ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8))
val response = HttpResponseContainer(content.expectedType, newResponse)
proceedWith(response)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept
val originalRequest =
when (content) {
is OutgoingContent.ByteArrayContent -> content.bytes()
is ByteArray -> content
else -> return@intercept
}
when (originalPath) {
OpenAIPathType.CHAT -> {
plugin.logger.info { "Intercepting chat request for path $originalPath" }
update(originalRequest, XefChatRequest::toMLflow)
}
OpenAIPathType.EMBEDDINGS -> {
plugin.logger.info { "Intercepting embeddings request for path $originalPath" }
update(originalRequest, XefEmbeddingsRequest::toMLflow)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
}

}
private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) {
scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content ->
val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept
val contentResponse = content.response as? ByteReadPacket ?: return@intercept
when (originalPath) {
OpenAIPathType.CHAT -> {
plugin.logger.info { "Intercepting chat response for path $originalPath" }
update(content.expectedType, contentResponse, MLflowChatResponse::toXef)
}
OpenAIPathType.EMBEDDINGS -> {
plugin.logger.info { "Intercepting embeddings response for path $originalPath" }
update(content.expectedType, contentResponse, MLflowEmbeddingsResponse::toXef)
}
else -> {
plugin.logger.warn { "$originalPath not supported" }
return@intercept
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import com.xebia.functional.xef.server.http.client.OpenAIPathType

class MLflowModelAdapterBuilder {

private var pathTypeMap: Map<String, OpenAIPathType> = LinkedHashMap()
private var pathTypeMap: Map<String, OpenAIPathType> = LinkedHashMap()

fun addToPath(path: String, pathType: OpenAIPathType) {
this.pathTypeMap += mapOf(path to pathType)
}
fun addToPath(path: String, pathType: OpenAIPathType) {
this.pathTypeMap += mapOf(path to pathType)
}

internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap)
}
internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap)
}
Loading

0 comments on commit d881c5a

Please sign in to comment.