From 18a6f2f3d1671def9ea919b5b8b3d58b1d2caa21 Mon Sep 17 00:00:00 2001
From: Mateusz Dybowski
Date: Tue, 7 Nov 2023 13:06:25 +0100
Subject: [PATCH] Add ZIO streaming support for chat completions (#127)

---
 README.md                                     |  12 +-
 build.sbt                                     |  11 ++
 project/Dependencies.scala                    |   2 +
 .../sttp/openai/streaming/zio/package.scala   |  48 ++++++++
 .../openai/streaming/zio/ZioClientSpec.scala  | 114 ++++++++++++++++++
 5 files changed, 184 insertions(+), 3 deletions(-)
 create mode 100644 streaming/zio/src/main/scala/sttp/openai/streaming/zio/package.scala
 create mode 100644 streaming/zio/src/test/scala/sttp/openai/streaming/zio/ZioClientSpec.scala

diff --git a/README.md b/README.md
index 1a32e227..08ad3483 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ object Main extends App {
   */
 }
 ```
-#### Currently only two backend implementations are available:
+#### Available backend implementations:
 * `OpenAISyncBackend` which uses identity monad `Id[A]` as an effect `F[A]` and throws `OpenAIException`
 * `OpenAI` which provides raw sttp `Request`s and wraps `Response`s into `Either[OpenAIException, A]`
 
@@ -140,14 +140,20 @@
 ```
 
 #### Create completion with streaming:
-The Chat Completion API features streaming support via server-sent events. Currently, we only support streaming using `Fs2`.
+The Chat Completion API features streaming support via server-sent events. Currently, we only support streaming using `Fs2` and `ZIO`.
 
-Add the following import:
+To use `Fs2`, add the following import:
 
 ```scala
 import sttp.openai.streaming.fs2._
 ```
 
+To use `ZIO`, add the following import:
+
+```scala
+import sttp.openai.streaming.zio._
+```
+
 Example below uses `HttpClientFs2Backend` as a backend.
 
 ```scala mdoc:compile-only
diff --git a/build.sbt b/build.sbt
index 7e1b0530..5ebf7295 100644
--- a/build.sbt
+++ b/build.sbt
@@ -19,6 +19,7 @@ lazy val root = (project in file("."))
 lazy val allAgregates = core.projectRefs ++
   fs2.projectRefs ++
+  zio.projectRefs ++
   docs.projectRefs
 
 lazy val core = (projectMatrix in file("core"))
@@ -40,6 +41,16 @@ lazy val fs2 = (projectMatrix in file("streaming/fs2"))
   )
   .dependsOn(core % "compile->compile;test->test")
 
+lazy val zio = (projectMatrix in file("streaming/zio"))
+  .jvmPlatform(
+    scalaVersions = scala2 ++ scala3
+  )
+  .settings(commonSettings)
+  .settings(
+    libraryDependencies += Libraries.sttpClientZio
+  )
+  .dependsOn(core % "compile->compile;test->test")
+
 val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
 compileDocs := {
   (docs.jvm(scala2.head) / mdoc).toTask(" --out target/sttp-openai-docs").value
diff --git a/project/Dependencies.scala b/project/Dependencies.scala
index 1cba935a..0d2c481f 100644
--- a/project/Dependencies.scala
+++ b/project/Dependencies.scala
@@ -24,6 +24,8 @@ object Dependencies {
     "org.typelevel" %% "cats-effect-testing-scalatest" % V.scalaTestCats % Test
   )
 
+  val sttpClientZio = "com.softwaremill.sttp.client4" %% "zio" % V.sttpClient
+
   val uPickle = "com.lihaoyi" %% "upickle" % V.uPickle
 }
diff --git a/streaming/zio/src/main/scala/sttp/openai/streaming/zio/package.scala b/streaming/zio/src/main/scala/sttp/openai/streaming/zio/package.scala
new file mode 100644
index 00000000..4f333e6a
--- /dev/null
+++ b/streaming/zio/src/main/scala/sttp/openai/streaming/zio/package.scala
@@ -0,0 +1,48 @@
+package sttp.openai.streaming
+
+import sttp.client4.StreamRequest
+import sttp.model.sse.ServerSentEvent
+import sttp.openai.OpenAI
+import sttp.openai.OpenAIExceptions.OpenAIException
+import 
sttp.openai.json.SttpUpickleApiExtension.deserializeJsonSnake +import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse +import sttp.openai.requests.completions.chat.ChatRequestBody.ChatBody +import _root_.zio.stream._ +import _root_.zio.ZIO +import sttp.capabilities.zio.ZioStreams +import sttp.client4.impl.zio.ZioServerSentEvents + +package object zio { + import ChatChunkResponse.DoneEventMessage + + implicit class extension(val client: OpenAI) { + + /** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody. + * + * [[https://platform.openai.com/docs/api-reference/chat/create]] + * + * @param chatBody + * Chat request body. + */ + def createStreamedChatCompletion( + chatBody: ChatBody + ): StreamRequest[Either[OpenAIException, Stream[Throwable, ChatChunkResponse]], ZioStreams] = + client + .createChatCompletion(ZioStreams, chatBody) + .mapResponse(mapEventToResponse) + } + + private def mapEventToResponse( + response: Either[OpenAIException, Stream[Throwable, Byte]] + ): Either[OpenAIException, Stream[Throwable, ChatChunkResponse]] = + response.map( + _.viaFunction(ZioServerSentEvents.parse) + .viaFunction(deserializeEvent) + ) + + private def deserializeEvent: ZioStreams.Pipe[ServerSentEvent, ChatChunkResponse] = + _.collectZIO { + case ServerSentEvent(Some(data), _, _, _) if data != DoneEventMessage => + ZIO.fromEither(deserializeJsonSnake[ChatChunkResponse].apply(data)) + } +} diff --git a/streaming/zio/src/test/scala/sttp/openai/streaming/zio/ZioClientSpec.scala b/streaming/zio/src/test/scala/sttp/openai/streaming/zio/ZioClientSpec.scala new file mode 100644 index 00000000..891e94f1 --- /dev/null +++ b/streaming/zio/src/test/scala/sttp/openai/streaming/zio/ZioClientSpec.scala @@ -0,0 +1,114 @@ +package sttp.openai.streaming.zio + +import org.scalatest.EitherValues +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import sttp.client4.DeserializationException +import sttp.client4.httpclient.zio.HttpClientZioBackend +import sttp.client4.testing.RawStream +import sttp.model.sse.ServerSentEvent +import sttp.openai.fixtures.ErrorFixture +import sttp.openai.json.SnakePickle._ +import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse +import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse.DoneEventMessage +import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel} +import sttp.openai.utils.JsonUtils.compactJson +import sttp.openai.{OpenAI, OpenAIExceptions} +import zio._ +import zio.stream._ + +class ZioClientSpec extends AnyFlatSpec with Matchers with EitherValues { + private val runtime: Runtime[Any] = Runtime.default + + for ((statusCode, expectedError) <- ErrorFixture.testData) + s"Service response with status code: $statusCode" should s"return properly deserialized ${expectedError.getClass.getSimpleName}" in { + // given + val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespondWithCode(statusCode, ErrorFixture.errorResponse) + val client = new OpenAI("test-token") + + val givenRequest = ChatBody( + model = ChatCompletionModel.GPT35Turbo, + messages = Seq.empty + ) + + // when + val caughtEffect: ZIO[Any, Throwable, OpenAIExceptions.OpenAIException] = client + .createStreamedChatCompletion(givenRequest) + .send(zioBackendStub) + .map(_.body.left.value) + + val caught = unsafeRun(caughtEffect) + + // then + caught.getClass shouldBe 
expectedError.getClass + caught.message shouldBe expectedError.message + caught.cause shouldBe expectedError.cause + caught.code shouldBe expectedError.code + caught.param shouldBe expectedError.param + caught.`type` shouldBe expectedError.`type` + } + + "Creating chat completions with failed stream due to invalid deserialization" should "return properly deserialized error" in { + // given + val invalidJson = Some("invalid json") + val invalidEvent = ServerSentEvent(invalidJson) + + val streamedResponse = ZStream + .succeed(invalidEvent.toString) + .via(ZPipeline.utf8Encode) + + val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespond(RawStream(streamedResponse)) + val client = new OpenAI(authToken = "test-token") + + val givenRequest = ChatBody( + model = ChatCompletionModel.GPT35Turbo, + messages = Seq.empty + ) + + // when + val responseEffect = client + .createStreamedChatCompletion(givenRequest) + .send(zioBackendStub) + .flatMap(_.body.value.runDrain) + + val response = unsafeRun(responseEffect.either) + + // then + response shouldBe a[Left[DeserializationException[_], _]] + } + + "Creating chat completions with successful response" should "return properly deserialized list of chunks" in { + // given + val chatChunks = Seq.fill(3)(sttp.openai.fixtures.ChatChunkFixture.jsonResponse).map(compactJson) + val events = chatChunks.map(data => ServerSentEvent(Some(data))) :+ ServerSentEvent(Some(DoneEventMessage)) + val delimiter = "\n\n" + val streamedResponse = ZStream + .from(events) + .map(_.toString + delimiter) + .via(ZPipeline.utf8Encode) + + val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespond(RawStream(streamedResponse)) + val client = new OpenAI(authToken = "test-token") + + val givenRequest = ChatBody( + model = ChatCompletionModel.GPT35Turbo, + messages = Seq.empty + ) + + // when + val responseEffect = client + .createStreamedChatCompletion(givenRequest) + .send(zioBackendStub) + .map(_.body.value) + .flatMap(_.runCollect) + + val response = unsafeRun(responseEffect) + + // then + val expectedResponse = chatChunks.map(read[ChatChunkResponse](_)) + response.toList shouldBe expectedResponse + } + + private def unsafeRun[E, A](zio: ZIO[Any, E, A]): A = + Unsafe.unsafe(implicit unsafe => runtime.unsafe.run(zio).getOrThrowFiberFailure()) +}
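
For reference (not part of the diff above): a minimal sketch of how the new ZIO streaming API is consumed end to end, mirroring the README's Fs2 example. `HttpClientZioBackend.scoped()` is assumed from sttp-client4's ZIO module, the API key is a placeholder, and `messages` is left empty for brevity, as in the tests above.

```scala
import sttp.client4.httpclient.zio.HttpClientZioBackend
import sttp.openai.OpenAI
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
import sttp.openai.streaming.zio._
import zio._

object StreamingExample extends ZIOAppDefault {
  override def run: ZIO[ZIOAppArgs with Scope, Any, Any] =
    for {
      // Acquire a streaming-capable backend; it is released when the scope closes.
      backend <- HttpClientZioBackend.scoped()
      client = new OpenAI("your-secret-api-key") // placeholder token
      chatBody = ChatBody(
        model = ChatCompletionModel.GPT35Turbo,
        messages = Seq.empty // supply real chat messages here
      )
      response <- client
        .createStreamedChatCompletion(chatBody)
        .send(backend)
      // response.body: Either[OpenAIException, Stream[Throwable, ChatChunkResponse]]
      _ <- response.body match {
        case Left(error)   => ZIO.fail(error)
        case Right(chunks) => chunks.runForeach(chunk => Console.printLine(chunk.toString))
      }
    } yield ()
}
```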