diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt index 080f4d937..1d809c42c 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt @@ -7,9 +7,6 @@ import com.typesafe.config.ConfigFactory import com.xebia.functional.xef.server.db.psql.Migrate import com.xebia.functional.xef.server.db.psql.XefDatabaseConfig import com.xebia.functional.xef.server.exceptions.exceptionsHandler -import com.xebia.functional.xef.server.http.client.ModelUriAdapter -import com.xebia.functional.xef.server.http.client.OpenAIPathType -import com.xebia.functional.xef.server.http.client.mlflow.MLflowModelAdapter import com.xebia.functional.xef.server.http.routes.* import com.xebia.functional.xef.server.services.PostgresVectorStoreService import com.xebia.functional.xef.server.services.RepositoryService @@ -19,9 +16,7 @@ import io.ktor.client.engine.cio.* import io.ktor.client.plugins.auth.* import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation import io.ktor.client.plugins.logging.* -import io.ktor.client.request.* import io.ktor.http.* -import io.ktor.http.content.* import io.ktor.serialization.kotlinx.json.* import io.ktor.server.application.* import io.ktor.server.auth.* @@ -65,12 +60,6 @@ object Server { install(Auth) install(Logging) { level = LogLevel.INFO } install(ClientContentNegotiation) - install(ModelUriAdapter) { - addToPath(OpenAIPathType.EMBEDDINGS, "ojete/calor" to "https://ca0f47a7-ade7-430a-8735-e1cea32ac960.mock.pstmn.io/http://127.0.0.1:5000/gateway/embeddings/invocations") - } - install(MLflowModelAdapter) { - addToPath("https://ca0f47a7-ade7-430a-8735-e1cea32ac960.mock.pstmn.io/http://127.0.0.1:5000/gateway/embeddings/invocations", OpenAIPathType.EMBEDDINGS) - } } server(factory = Netty, port = 8081, host = "0.0.0.0") { diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt index 0d9c977c5..9d256d74f 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt @@ -11,54 +11,58 @@ import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonPrimitive -class ModelUriAdapter internal constructor(private val urlMap: Map>) { +class ModelUriAdapter +internal constructor(private val urlMap: Map>) { - val logger = KotlinLogging.logger {} + val logger = KotlinLogging.logger {} - fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path) + fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path) - fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model) + fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model) - companion object : HttpClientPlugin { + companion object : HttpClientPlugin { - override val key: AttributeKey = AttributeKey("ModelAuthAdapter") + override val key: AttributeKey = AttributeKey("ModelAuthAdapter") - override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter = - ModelUriAdapterBuilder().apply(block).build() + override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter = + ModelUriAdapterBuilder().apply(block).build() - override fun install(plugin: ModelUriAdapter, scope: HttpClient) { - installModelUriAdapter(plugin, scope) - } + override fun install(plugin: ModelUriAdapter, scope: HttpClient) { + installModelUriAdapter(plugin, scope) + } - private fun readModelFromRequest(originalRequest: ByteArray): String? { - val requestBody = originalRequest.toString(Charsets.UTF_8) - val json = Json.decodeFromString(requestBody) - return json["model"]?.jsonPrimitive?.content - } + private fun readModelFromRequest(originalRequest: ByteArray): String? { + val requestBody = originalRequest.toString(Charsets.UTF_8) + val json = Json.decodeFromString(requestBody) + return json["model"]?.jsonPrimitive?.content + } - private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) { - scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> - val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept - if (!plugin.isDefined(originalPath)) return@intercept - val model = when (content) { - is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes()) - is ByteArray -> readModelFromRequest(content) - else -> return@intercept - } - val newURL = model?.let { plugin.findPath(originalPath, it) } - if (newURL == null) plugin.logger.info { "New url for path $originalPath and model $model not found" } - else { - plugin.logger.info { "Intercepting request for path $originalPath and model $model to $newURL" } - val baseBuilder = URLBuilder(newURL).build() - context.url.set( - scheme = baseBuilder.protocol.name, - host = baseBuilder.host, - port = baseBuilder.port, - path = baseBuilder.encodedPath - ) - } - } + private fun installModelUriAdapter(plugin: ModelUriAdapter, scope: HttpClient) { + scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> + val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept + if (!plugin.isDefined(originalPath)) return@intercept + val model = + when (content) { + is OutgoingContent.ByteArrayContent -> readModelFromRequest(content.bytes()) + is ByteArray -> readModelFromRequest(content) + else -> return@intercept + } + val newURL = model?.let { plugin.findPath(originalPath, it) } + if (newURL == null) + plugin.logger.info { "New url for path $originalPath and model $model not found" } + else { + plugin.logger.info { + "Intercepting request for path $originalPath and model $model to $newURL" + } + val baseBuilder = URLBuilder(newURL).build() + context.url.set( + scheme = baseBuilder.protocol.name, + host = baseBuilder.host, + port = baseBuilder.port, + path = baseBuilder.encodedPath + ) } - + } } -} \ No newline at end of file + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt index 5c4f7d69c..a2bda8039 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt @@ -14,4 +14,4 @@ class ModelUriAdapterBuilder { } internal fun build(): ModelUriAdapter = ModelUriAdapter(pathMap) -} \ No newline at end of file +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt index 5ec96d007..17ad79ec0 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapter.kt @@ -1,89 +1,113 @@ package com.xebia.functional.xef.server.http.client.mlflow -import com.xebia.functional.xef.server.http.client.ModelUriAdapter import com.xebia.functional.xef.server.http.client.OpenAIPathType import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.* +import io.ktor.client.call.* import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* import io.ktor.http.content.* import io.ktor.util.* +import io.ktor.util.pipeline.* +import io.ktor.util.reflect.* import io.ktor.utils.io.core.* import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json -class MLflowModelAdapter internal constructor(private val mappedRequests: Map) { +class MLflowModelAdapter +internal constructor(private val mappedRequests: Map) { - val logger = KotlinLogging.logger {} + val logger = KotlinLogging.logger {} - fun mappedType(path: String): OpenAIPathType? = mappedRequests[path] + fun mappedType(path: String): OpenAIPathType? = mappedRequests[path] - companion object : HttpClientPlugin { + companion object : HttpClientPlugin { - @OptIn(ExperimentalSerializationApi::class) - private val json = Json { - explicitNulls = false - ignoreUnknownKeys = true - } + @OptIn(ExperimentalSerializationApi::class) + private val json = Json { + explicitNulls = false + ignoreUnknownKeys = true + } - override val key: AttributeKey = AttributeKey("MLflowAdapter") + override val key: AttributeKey = AttributeKey("MLflowAdapter") - override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter = - MLflowModelAdapterBuilder().apply(block).build() + override fun prepare(block: MLflowModelAdapterBuilder.() -> Unit): MLflowModelAdapter = + MLflowModelAdapterBuilder().apply(block).build() - override fun install(plugin: MLflowModelAdapter, scope: HttpClient) { - installMLflowRequestAdapter(plugin, scope) - installMLflowResponseAdapter(plugin, scope) - } + override fun install(plugin: MLflowModelAdapter, scope: HttpClient) { + installMLflowRequestAdapter(plugin, scope) + installMLflowResponseAdapter(plugin, scope) + } - private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { - scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> - val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept - val originalRequest = when (content) { - is OutgoingContent.ByteArrayContent -> content.bytes() - is ByteArray -> content - else -> return@intercept - } - when (originalPath) { - OpenAIPathType.CHAT -> TODO() - OpenAIPathType.EMBEDDINGS -> { - plugin.logger.info { "Intercepting request for path $originalPath" } - val stringRequestBody = originalRequest.toString(Charsets.UTF_8) - val requestData = json.decodeFromString(stringRequestBody) - val newRequest = TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json) - proceedWith(newRequest) - } - else -> { - plugin.logger.warn { "$originalPath not supported" } - return@intercept - } - } - } - } + private suspend inline fun PipelineContext + .update(originalRequest: ByteArray, toMLflow: R1.() -> R2) { + val stringRequestBody = originalRequest.toString(Charsets.UTF_8) + val requestData = json.decodeFromString(stringRequestBody) + val newRequest = + TextContent(json.encodeToString(requestData.toMLflow()), ContentType.Application.Json) + proceedWith(newRequest) + } + + private suspend inline fun PipelineContext< + HttpResponseContainer, HttpClientCall + > + .update(typeInfo: TypeInfo, contentResponse: ByteReadPacket, toXef: R1.() -> R2) { + val stringResponseBody = contentResponse.readText() + val responseData = json.decodeFromString(stringResponseBody) + val newResponse = + ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8)) + val response = HttpResponseContainer(typeInfo, newResponse) + proceedWith(response) + } - private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { - scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content -> - val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept - val contentResponse = content.response as? ByteReadPacket ?: return@intercept - when (originalPath) { - OpenAIPathType.CHAT -> TODO() - OpenAIPathType.EMBEDDINGS -> { - val stringResponseBody = contentResponse.readText() - val responseData = json.decodeFromString(stringResponseBody) - val newResponse = ByteReadPacket(json.encodeToString(responseData.toXef()).toByteArray(Charsets.UTF_8)) - val response = HttpResponseContainer(content.expectedType, newResponse) - proceedWith(response) - } - else -> { - plugin.logger.warn { "$originalPath not supported" } - return@intercept - } - } - } + private fun installMLflowRequestAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { + scope.requestPipeline.intercept(HttpRequestPipeline.Transform) { content -> + val originalPath = plugin.mappedType(context.url.buildString()) ?: return@intercept + val originalRequest = + when (content) { + is OutgoingContent.ByteArrayContent -> content.bytes() + is ByteArray -> content + else -> return@intercept + } + when (originalPath) { + OpenAIPathType.CHAT -> { + plugin.logger.info { "Intercepting chat request for path $originalPath" } + update(originalRequest, XefChatRequest::toMLflow) + } + OpenAIPathType.EMBEDDINGS -> { + plugin.logger.info { "Intercepting embeddings request for path $originalPath" } + update(originalRequest, XefEmbeddingsRequest::toMLflow) + } + else -> { + plugin.logger.warn { "$originalPath not supported" } + return@intercept + } } + } } -} \ No newline at end of file + private fun installMLflowResponseAdapter(plugin: MLflowModelAdapter, scope: HttpClient) { + scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { content -> + val originalPath = plugin.mappedType(context.request.url.toString()) ?: return@intercept + val contentResponse = content.response as? ByteReadPacket ?: return@intercept + when (originalPath) { + OpenAIPathType.CHAT -> { + plugin.logger.info { "Intercepting chat response for path $originalPath" } + update(content.expectedType, contentResponse, MLflowChatResponse::toXef) + } + OpenAIPathType.EMBEDDINGS -> { + plugin.logger.info { "Intercepting embeddings response for path $originalPath" } + update(content.expectedType, contentResponse, MLflowEmbeddingsResponse::toXef) + } + else -> { + plugin.logger.warn { "$originalPath not supported" } + return@intercept + } + } + } + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt index 8df705e23..62a49dee3 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/MLflowModelAdapterBuilder.kt @@ -4,11 +4,11 @@ import com.xebia.functional.xef.server.http.client.OpenAIPathType class MLflowModelAdapterBuilder { - private var pathTypeMap: Map = LinkedHashMap() + private var pathTypeMap: Map = LinkedHashMap() - fun addToPath(path: String, pathType: OpenAIPathType) { - this.pathTypeMap += mapOf(path to pathType) - } + fun addToPath(path: String, pathType: OpenAIPathType) { + this.pathTypeMap += mapOf(path to pathType) + } - internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap) -} \ No newline at end of file + internal fun build(): MLflowModelAdapter = MLflowModelAdapter(pathTypeMap) +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt index ae8a1a0c2..3902ba800 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/mlflow/OpenAIModels.kt @@ -1,5 +1,6 @@ package com.xebia.functional.xef.server.http.client.mlflow +import io.ktor.util.date.* import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.builtins.ListSerializer @@ -8,92 +9,192 @@ import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonTransformingSerializer import kotlinx.serialization.json.jsonArray +import kotlinx.uuid.UUID +import kotlinx.uuid.generateUUID + +@Serializable +data class XefChatRequest( + val model: String, + @Serializable(with = XefChatMessageSerializable::class) val messages: List, + val temperature: Double? = null, + val n: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null, + val stream: Boolean? = false +) + +@Serializable data class XefChatMessage(val role: String, val content: String) @Serializable data class XefEmbeddingsRequest( - val model: String, - @Serializable(with = StringArraySerializable::class) val input: List, - val encodingFormat: XefEncodingFormat? = null + val model: String, + @Serializable(with = StringArraySerializable::class) val input: List, + val encodingFormat: XefEncodingFormat? = null ) @Serializable enum class XefEncodingFormat { - @SerialName("float") FLOAT, - @SerialName("base64") BASE64 + @SerialName("float") FLOAT, + @SerialName("base64") BASE64 } @Serializable -data class MLflowEmbeddingsRequest( - @Serializable(with = StringArraySerializable::class) val text: List +data class XefChatResponse( + val id: String, + @SerialName("object") val objectModel: String, + val created: Long, + val model: String, + val choices: List, + val usage: XefResponseUsage ) @Serializable -data class MLflowEmbeddingsResponse( - val embeddings: List>, - val metadata: MLflowResponseMetadata +data class XefResponseChoice( + val index: Int, + val message: XefResponseMessage, + @SerialName("finish_reason") val finishReason: String? ) +@Serializable data class XefResponseMessage(val role: String, val content: String) + @Serializable -data class MLflowResponseMetadata( - val model: String, - @SerialName("route_type") val routeType: MLflowRouteType, - @SerialName("input_tokens") val inputTokens: Int? = null, - @SerialName("output_tokens") val outputTokens: Int? = null, - @SerialName("total_tokens") val totalTokens: Int? = null +data class XefResponseUsage( + @SerialName("prompt_tokens") val promptTokens: Int? = null, + @SerialName("completion_tokens") val completionTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int? = null ) @Serializable -enum class MLflowRouteType { - @SerialName("llm/v1/completions") COMPLETIONS, - @SerialName("llm/v1/chat") CHAT, - @SerialName("llm/v1/embeddings") EMBEDDINGS -} +data class XefEmbeddingResponse( + @SerialName("object") val objectModel: String, + val data: List, + val model: String, + val usage: XefResponseUsage +) -@Serializable data class XefResponseMessage(val role: String, val content: String) +@Serializable +data class XefEmbedding( + @SerialName("object") val objectModel: String, + val index: Int, + val embedding: List +) @Serializable -data class XefResponseUsage( - @SerialName("prompt_tokens") val promptTokens: Int? = null, - @SerialName("completion_tokens") val completionTokens: Int? = null, - @SerialName("total_tokens") val totalTokens: Int? = null +data class MLflowChatRequest( + @Serializable(with = MLflowChatMessageSerializable::class) val messages: List, + val temperature: Double? = null, + @SerialName("candidate_count") val candidateCount: Int? = null, + val stop: List? = null, + @SerialName("max_tokens") val maxTokens: Int? = null ) @Serializable -data class XefEmbeddingResponse( - @SerialName("object") val objectModel: String, - val data: List, - val model: String, - val usage: XefResponseUsage +data class MLflowChatResponse( + val candidates: List, + val metadata: MLflowResponseMetadata ) @Serializable -data class XefEmbedding( - @SerialName("object") val objectModel: String, - val index: Int, - val embedding: List +data class MLflowChatCandidate( + val message: MLflowChatMessage, + val metadata: MLflowCandidateMetadata +) + +@Serializable data class MLflowChatMessage(val role: String, val content: String) + +@Serializable +data class MLflowCandidateMetadata(@SerialName("finish_reason") val finishReason: String?) + +@Serializable +data class MLflowEmbeddingsRequest( + @Serializable(with = StringArraySerializable::class) val text: List +) + +@Serializable +data class MLflowEmbeddingsResponse( + val embeddings: List>, + val metadata: MLflowResponseMetadata ) +@Serializable +data class MLflowResponseMetadata( + val model: String, + @SerialName("route_type") val routeType: MLflowRouteType, + @SerialName("input_tokens") val inputTokens: Int? = null, + @SerialName("output_tokens") val outputTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int? = null +) + +@Serializable +enum class MLflowRouteType { + @SerialName("llm/v1/completions") COMPLETIONS, + @SerialName("llm/v1/chat") CHAT, + @SerialName("llm/v1/embeddings") EMBEDDINGS +} + +private object XefChatMessageSerializable : + JsonTransformingSerializer>(ListSerializer(XefChatMessage.serializer())) { + override fun transformDeserialize(element: JsonElement): JsonElement = + if (element !is JsonArray) JsonArray(listOf(element)) else element +} + +private object MLflowChatMessageSerializable : + JsonTransformingSerializer>( + ListSerializer(MLflowChatMessage.serializer()) + ) { + override fun transformSerialize(element: JsonElement): JsonElement = + if (element is JsonArray) { + val jsonArray = element.jsonArray + if (jsonArray.size == 1) jsonArray[0] else jsonArray + } else element +} + private object StringArraySerializable : - JsonTransformingSerializer>(ListSerializer(String.serializer())) { - override fun transformDeserialize(element: JsonElement): JsonElement = - if (element !is JsonArray) JsonArray(listOf(element)) else element - - override fun transformSerialize(element: JsonElement): JsonElement = - if (element is JsonArray) { - val jsonArray = element.jsonArray - if (jsonArray.size == 1) jsonArray[0] else jsonArray - } else element + JsonTransformingSerializer>(ListSerializer(String.serializer())) { + override fun transformDeserialize(element: JsonElement): JsonElement = + if (element !is JsonArray) JsonArray(listOf(element)) else element + + override fun transformSerialize(element: JsonElement): JsonElement = + if (element is JsonArray) { + val jsonArray = element.jsonArray + if (jsonArray.size == 1) jsonArray[0] else jsonArray + } else element } +fun XefChatRequest.toMLflow(): MLflowChatRequest = + MLflowChatRequest( + messages.map { MLflowChatMessage(it.role, it.content) }, + temperature, + n, + stop, + maxTokens + ) + fun XefEmbeddingsRequest.toMLflow(): MLflowEmbeddingsRequest = MLflowEmbeddingsRequest(input) +fun MLflowChatResponse.toXef(): XefChatResponse = + XefChatResponse( + UUID.generateUUID().toString(), + "chat.completion", + getTimeMillis(), + metadata.model, + candidates.mapIndexed { index, candidate -> + XefResponseChoice( + index, + XefResponseMessage(candidate.message.role, candidate.message.content), + candidate.metadata.finishReason + ) + }, + metadata.toXef() + ) + fun MLflowEmbeddingsResponse.toXef(): XefEmbeddingResponse = - XefEmbeddingResponse( - "list", - embeddings.mapIndexed { index, list -> XefEmbedding("embedding", index, list) }, - metadata.model, - metadata.toXef().copy(completionTokens = null) - ) + XefEmbeddingResponse( + "list", + embeddings.mapIndexed { index, list -> XefEmbedding("embedding", index, list) }, + metadata.model, + metadata.toXef().copy(completionTokens = null) + ) private fun MLflowResponseMetadata.toXef(): XefResponseUsage = - XefResponseUsage(inputTokens, outputTokens, totalTokens) \ No newline at end of file + XefResponseUsage(inputTokens, outputTokens, totalTokens) diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt index 407c12430..15d8f6dbc 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt @@ -4,7 +4,6 @@ import com.aallam.openai.api.BetaOpenAI import com.xebia.functional.xef.server.models.Token import com.xebia.functional.xef.server.models.exceptions.XefExceptions import io.ktor.client.* -import io.ktor.client.call.* import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* @@ -20,27 +19,12 @@ import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.boolean import kotlinx.serialization.json.jsonPrimitive -enum class Provider { - OPENAI, - GPT4ALL, - GCP -} - -fun String.toProvider(): Provider? = - when (this) { - "openai" -> Provider.OPENAI - "gpt4all" -> Provider.GPT4ALL - "gcp" -> Provider.GCP - else -> Provider.OPENAI - } - @OptIn(BetaOpenAI::class) fun Routing.aiRoutes(client: HttpClient) { val openAiUrl = "https://api.openai.com/v1" authenticate("auth-bearer") { post("/chat/completions") { - val token = call.getToken() val byteArrayBody = call.receiveChannel().toByteArray() val body = byteArrayBody.toString(Charsets.UTF_8) val data = Json.decodeFromString(body) @@ -48,26 +32,24 @@ fun Routing.aiRoutes(client: HttpClient) { val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false if (!isStream) { - client.makeRequest(call, "$openAiUrl/chat/completions", byteArrayBody, token) + client.makeRequest(call, "$openAiUrl/chat/completions", byteArrayBody) } else { - client.makeStreaming(call, "$openAiUrl/chat/completions", byteArrayBody, token) + client.makeStreaming(call, "$openAiUrl/chat/completions", byteArrayBody) } } post("/embeddings") { - val token = call.getToken() val context = call.receiveChannel().toByteArray() - client.makeRequest(call, "$openAiUrl/embeddings", context, token) + client.makeRequest(call, "$openAiUrl/embeddings", context) } } } -private suspend fun HttpClient.makeRequest( - call: ApplicationCall, - url: String, - body: ByteArray, - token: Token -) { +private val conflictingRequestHeaders = + listOf("Host", "Content-Type", "Content-Length", "Accept", "Accept-Encoding") +private val conflictingResponseHeaders = listOf("Content-Length") + +private suspend fun HttpClient.makeRequest(call: ApplicationCall, url: String, body: ByteArray) { val response = this.request(url) { headers.copyFrom(call.request.headers) @@ -75,43 +57,35 @@ private suspend fun HttpClient.makeRequest( method = HttpMethod.Post setBody(body) } - call.response.headers.copyFrom(response.headers, "Content-Length") - call.respond(response.status, response.readBytes()) + call.response.headers.copyFrom(response.headers) + call.respond(response.status, response.bodyAsText()) } -private suspend fun HttpClient.makeStreaming( - call: ApplicationCall, - url: String, - body: ByteArray, - token: Token -) { +private suspend fun HttpClient.makeStreaming(call: ApplicationCall, url: String, body: ByteArray) { this.preparePost(url) { headers.copyFrom(call.request.headers) method = HttpMethod.Post setBody(body) } .execute { httpResponse -> - call.response.headers.copyFrom(httpResponse.headers, "Content-Length") + call.response.headers.copyFrom(httpResponse.headers) call.respondOutputStream { httpResponse.bodyAsChannel().copyTo(this@respondOutputStream) } } } -private fun ResponseHeaders.copyFrom(headers: Headers, vararg filterOut: String) = +private fun ResponseHeaders.copyFrom(headers: Headers) = headers .entries() .filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception - .filterNot { (key, _) -> filterOut.any { it.equals(key, true) } } + .filterNot { (key, _) -> conflictingResponseHeaders.any { it.equals(key, true) } } .forEach { (key, values) -> values.forEach { value -> this.appendIfAbsent(key, value) } } internal fun HeadersBuilder.copyFrom(headers: Headers) = headers - .filter { key, value -> !key.equals("HOST", ignoreCase = true) } - .forEach { key, values -> appendAll(key, values) } - -private fun ApplicationCall.getProvider(): Provider = - request.headers["xef-provider"]?.toProvider() ?: Provider.OPENAI + .filter { key, _ -> !conflictingRequestHeaders.any { it.equals(key, true) } } + .forEach { key, values -> appendMissing(key, values) } fun ApplicationCall.getToken(): Token = principal()?.name?.let { Token(it) } diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt index 06364727b..9a1522122 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/LocalVectorStoreService.kt @@ -1,7 +1,6 @@ package com.xebia.functional.xef.server.services import com.xebia.functional.xef.conversation.llm.openai.OpenAI -import com.xebia.functional.xef.server.http.routes.Provider import com.xebia.functional.xef.store.LocalVectorStore import com.xebia.functional.xef.store.VectorStore diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt index 665464834..0dc225adb 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/PostgresVectorStoreService.kt @@ -2,7 +2,6 @@ package com.xebia.functional.xef.server.services import com.xebia.functional.xef.conversation.llm.openai.OpenAI import com.xebia.functional.xef.llm.models.embeddings.RequestConfig -import com.xebia.functional.xef.server.http.routes.Provider import com.xebia.functional.xef.store.PGVectorStore import com.xebia.functional.xef.store.VectorStore import com.xebia.functional.xef.store.postgresql.PGDistanceStrategy @@ -53,7 +52,6 @@ class PostgresVectorStoreService( val embeddings = when (provider) { Provider.OPENAI -> OpenAI(token).DEFAULT_EMBEDDING - else -> OpenAI(token).DEFAULT_EMBEDDING } return PGVectorStore( diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/services/VectorStoreService.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/services/VectorStoreService.kt index fad7548d8..4cfa39c43 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/services/VectorStoreService.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/services/VectorStoreService.kt @@ -5,7 +5,6 @@ import com.typesafe.config.ConfigFactory import com.xebia.functional.xef.server.db.VectorStoreConfig import com.xebia.functional.xef.server.db.local.LocalVectorStoreConfig import com.xebia.functional.xef.server.db.psql.PSQLVectorStoreConfig -import com.xebia.functional.xef.server.http.routes.Provider import com.xebia.functional.xef.store.VectorStore import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext @@ -18,6 +17,10 @@ enum class XefVectorStoreType { LOCAL } +enum class Provider { + OPENAI +} + abstract class VectorStoreService { abstract fun getVectorStore( provider: Provider = Provider.OPENAI,