From 39cd65b06d788a7d9b05f2d83fc163899bd35e56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fede=20Fern=C3=A1ndez?= <720923+fedefernandez@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:05:19 +0100 Subject: [PATCH] Allow custom mappings for different models --- .../xef/server/http/client/ModelUriAdapter.kt | 79 +++++++++++++++++++ .../http/client/ModelUriAdapterBuilder.kt | 17 ++++ .../xef/server/http/client/OpenAIPathType.kt | 15 ++++ .../xef/server/http/routes/AIRoutes.kt | 7 +- 4 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt create mode 100644 server/src/main/kotlin/com/xebia/functional/xef/server/http/client/OpenAIPathType.kt diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt new file mode 100644 index 000000000..628560be9 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapter.kt @@ -0,0 +1,79 @@ +package com.xebia.functional.xef.server.http.client + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.* +import io.ktor.client.plugins.* +import io.ktor.client.request.* +import io.ktor.http.* +import io.ktor.http.content.* +import io.ktor.util.* +import io.ktor.util.pipeline.* +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonPrimitive + +class ModelUriAdapter +internal constructor(private val urlMap: Map>) { + + val logger = KotlinLogging.logger {} + + fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path) + + fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model) + + companion object : HttpClientPlugin { + + override val key: AttributeKey = AttributeKey("ModelAuthAdapter") + + override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter = + ModelUriAdapterBuilder().apply(block).build() + + override fun install(plugin: ModelUriAdapter, scope: HttpClient) { + installModelAuthAdapter(plugin, scope) + } + + private fun readModelFromRequest(originalRequest: OutgoingContent.ByteArrayContent?): String? { + val requestBody = originalRequest?.bytes()?.toString(Charsets.UTF_8) + val json = requestBody?.let { Json.decodeFromString(it) } + return json?.get("model")?.jsonPrimitive?.content + } + + private fun installModelAuthAdapter(plugin: ModelUriAdapter, scope: HttpClient) { + val adaptAuthRequestPhase = PipelinePhase("ModelAuthAdaptRequest") + scope.sendPipeline.insertPhaseAfter(HttpSendPipeline.State, adaptAuthRequestPhase) + scope.sendPipeline.intercept(adaptAuthRequestPhase) { content -> + val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept + if (plugin.isDefined(originalPath)) { + val originalRequest = content as? OutgoingContent.ByteArrayContent + if (originalRequest == null) { + plugin.logger.warn { + """ + |Can't adapt the model auth. + |The body type is: ${content::class}, with Content-Type: ${context.contentType()}. + | + |If you expect serialized body, please check that you have installed the corresponding + |plugin(like `ContentNegotiation`) and set `Content-Type` header.""" + .trimMargin() + } + return@intercept + } + val model = readModelFromRequest(originalRequest) + val newURL = model?.let { plugin.findPath(originalPath, it) } + if (newURL == null) { + plugin.logger.info { + "Model auth didn't found a new url for path $originalPath and model $model" + } + } else { + val baseBuilder = URLBuilder(newURL).build() + context.url.set( + scheme = baseBuilder.protocol.name, + host = baseBuilder.host, + port = baseBuilder.port, + path = baseBuilder.encodedPath + ) + } + } + } + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt new file mode 100644 index 000000000..a2bda8039 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/ModelUriAdapterBuilder.kt @@ -0,0 +1,17 @@ +package com.xebia.functional.xef.server.http.client + +class ModelUriAdapterBuilder { + + private var pathMap: Map> = LinkedHashMap() + + fun setPathMap(pathMap: Map>) { + this.pathMap = pathMap + } + + fun addToPath(path: OpenAIPathType, vararg modelUriPaths: Pair) { + val newPathTypeMap = mapOf(*modelUriPaths.map { Pair(it.first, it.second) }.toTypedArray()) + this.pathMap += mapOf(path to newPathTypeMap) + } + + internal fun build(): ModelUriAdapter = ModelUriAdapter(pathMap) +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/OpenAIPathType.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/OpenAIPathType.kt new file mode 100644 index 000000000..786fc2b80 --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/client/OpenAIPathType.kt @@ -0,0 +1,15 @@ +package com.xebia.functional.xef.server.http.client + +enum class OpenAIPathType(val value: String) { + CHAT("/v1/chat/completions"), + EMBEDDINGS("/v1/embeddings"), + FINE_TUNING("/v1/fine_tuning/jobs"), + FILES("/v1/files"), + IMAGES("/v1/images/generations"), + MODELS("/v1/models"), + MODERATION("/v1/moderations"); + + companion object { + fun from(v: String): OpenAIPathType? = entries.find { it.value == v } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt index 99bb38f43..407c12430 100644 --- a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/AIRoutes.kt @@ -75,7 +75,7 @@ private suspend fun HttpClient.makeRequest( method = HttpMethod.Post setBody(body) } - call.response.headers.copyFrom(response.headers) + call.response.headers.copyFrom(response.headers, "Content-Length") call.respond(response.status, response.readBytes()) } @@ -91,17 +91,18 @@ private suspend fun HttpClient.makeStreaming( setBody(body) } .execute { httpResponse -> - call.response.headers.copyFrom(httpResponse.headers) + call.response.headers.copyFrom(httpResponse.headers, "Content-Length") call.respondOutputStream { httpResponse.bodyAsChannel().copyTo(this@respondOutputStream) } } } -private fun ResponseHeaders.copyFrom(headers: Headers) = +private fun ResponseHeaders.copyFrom(headers: Headers, vararg filterOut: String) = headers .entries() .filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception + .filterNot { (key, _) -> filterOut.any { it.equals(key, true) } } .forEach { (key, values) -> values.forEach { value -> this.appendIfAbsent(key, value) } } internal fun HeadersBuilder.copyFrom(headers: Headers) =