Skip to content

Commit

Permalink
Add ZIO streaming support for chat completions (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
DybekK authored Nov 7, 2023
1 parent dfd0826 commit 18a6f2f
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 3 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object Main extends App {
*/
}
```
#### Currently only two backend implementations are available:
#### Available backend implementations:
* `OpenAISyncBackend` which uses identity monad `Id[A]` as an effect `F[A]` and throws `OpenAIException`
* `OpenAI` which provides raw sttp `Request`s and wraps `Response`s into `Either[OpenAIException, A]`

Expand Down Expand Up @@ -140,14 +140,20 @@ object Main extends IOApp {
```

#### Create completion with streaming:
The Chat Completion API features streaming support via server-sent events. Currently, we only support streaming using `Fs2`.
The Chat Completion API features streaming support via server-sent events. Currently, we only support streaming using `Fs2` and `ZIO`

Add the following import:
To use `Fs2` add the following import:

```scala
import sttp.openai.streaming.fs2._
```

To use `ZIO` add the following import:

```scala
import sttp.openai.streaming.zio._
```

Example below uses `HttpClientFs2Backend` as a backend.

```scala mdoc:compile-only
Expand Down
11 changes: 11 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ lazy val root = (project in file("."))

lazy val allAgregates = core.projectRefs ++
fs2.projectRefs ++
zio.projectRefs ++
docs.projectRefs

lazy val core = (projectMatrix in file("core"))
Expand All @@ -40,6 +41,16 @@ lazy val fs2 = (projectMatrix in file("streaming/fs2"))
)
.dependsOn(core % "compile->compile;test->test")

lazy val zio = (projectMatrix in file("streaming/zio"))
.jvmPlatform(
scalaVersions = scala2 ++ scala3
)
.settings(commonSettings)
.settings(
libraryDependencies += Libraries.sttpClientZio
)
.dependsOn(core % "compile->compile;test->test")

val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
compileDocs := {
(docs.jvm(scala2.head) / mdoc).toTask(" --out target/sttp-openai-docs").value
Expand Down
2 changes: 2 additions & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ object Dependencies {
"org.typelevel" %% "cats-effect-testing-scalatest" % V.scalaTestCats % Test
)

val sttpClientZio = "com.softwaremill.sttp.client4" %% "zio" % V.sttpClient

val uPickle = "com.lihaoyi" %% "upickle" % V.uPickle

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sttp.openai.streaming

import sttp.client4.StreamRequest
import sttp.model.sse.ServerSentEvent
import sttp.openai.OpenAI
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.json.SttpUpickleApiExtension.deserializeJsonSnake
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse
import sttp.openai.requests.completions.chat.ChatRequestBody.ChatBody
import _root_.zio.stream._
import _root_.zio.ZIO
import sttp.capabilities.zio.ZioStreams
import sttp.client4.impl.zio.ZioServerSentEvents

package object zio {
import ChatChunkResponse.DoneEventMessage

implicit class extension(val client: OpenAI) {

/** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
* @param chatBody
* Chat request body.
*/
def createStreamedChatCompletion(
chatBody: ChatBody
): StreamRequest[Either[OpenAIException, Stream[Throwable, ChatChunkResponse]], ZioStreams] =
client
.createChatCompletion(ZioStreams, chatBody)
.mapResponse(mapEventToResponse)
}

private def mapEventToResponse(
response: Either[OpenAIException, Stream[Throwable, Byte]]
): Either[OpenAIException, Stream[Throwable, ChatChunkResponse]] =
response.map(
_.viaFunction(ZioServerSentEvents.parse)
.viaFunction(deserializeEvent)
)

private def deserializeEvent: ZioStreams.Pipe[ServerSentEvent, ChatChunkResponse] =
_.collectZIO {
case ServerSentEvent(Some(data), _, _, _) if data != DoneEventMessage =>
ZIO.fromEither(deserializeJsonSnake[ChatChunkResponse].apply(data))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package sttp.openai.streaming.zio

import org.scalatest.EitherValues
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import sttp.client4.DeserializationException
import sttp.client4.httpclient.zio.HttpClientZioBackend
import sttp.client4.testing.RawStream
import sttp.model.sse.ServerSentEvent
import sttp.openai.fixtures.ErrorFixture
import sttp.openai.json.SnakePickle._
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse.DoneEventMessage
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
import sttp.openai.utils.JsonUtils.compactJson
import sttp.openai.{OpenAI, OpenAIExceptions}
import zio._
import zio.stream._

class ZioClientSpec extends AnyFlatSpec with Matchers with EitherValues {
private val runtime: Runtime[Any] = Runtime.default

for ((statusCode, expectedError) <- ErrorFixture.testData)
s"Service response with status code: $statusCode" should s"return properly deserialized ${expectedError.getClass.getSimpleName}" in {
// given
val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespondWithCode(statusCode, ErrorFixture.errorResponse)
val client = new OpenAI("test-token")

val givenRequest = ChatBody(
model = ChatCompletionModel.GPT35Turbo,
messages = Seq.empty
)

// when
val caughtEffect: ZIO[Any, Throwable, OpenAIExceptions.OpenAIException] = client
.createStreamedChatCompletion(givenRequest)
.send(zioBackendStub)
.map(_.body.left.value)

val caught = unsafeRun(caughtEffect)

// then
caught.getClass shouldBe expectedError.getClass
caught.message shouldBe expectedError.message
caught.cause shouldBe expectedError.cause
caught.code shouldBe expectedError.code
caught.param shouldBe expectedError.param
caught.`type` shouldBe expectedError.`type`
}

"Creating chat completions with failed stream due to invalid deserialization" should "return properly deserialized error" in {
// given
val invalidJson = Some("invalid json")
val invalidEvent = ServerSentEvent(invalidJson)

val streamedResponse = ZStream
.succeed(invalidEvent.toString)
.via(ZPipeline.utf8Encode)

val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespond(RawStream(streamedResponse))
val client = new OpenAI(authToken = "test-token")

val givenRequest = ChatBody(
model = ChatCompletionModel.GPT35Turbo,
messages = Seq.empty
)

// when
val responseEffect = client
.createStreamedChatCompletion(givenRequest)
.send(zioBackendStub)
.flatMap(_.body.value.runDrain)

val response = unsafeRun(responseEffect.either)

// then
response shouldBe a[Left[DeserializationException[_], _]]
}

"Creating chat completions with successful response" should "return properly deserialized list of chunks" in {
// given
val chatChunks = Seq.fill(3)(sttp.openai.fixtures.ChatChunkFixture.jsonResponse).map(compactJson)
val events = chatChunks.map(data => ServerSentEvent(Some(data))) :+ ServerSentEvent(Some(DoneEventMessage))
val delimiter = "\n\n"
val streamedResponse = ZStream
.from(events)
.map(_.toString + delimiter)
.via(ZPipeline.utf8Encode)

val zioBackendStub = HttpClientZioBackend.stub.whenAnyRequest.thenRespond(RawStream(streamedResponse))
val client = new OpenAI(authToken = "test-token")

val givenRequest = ChatBody(
model = ChatCompletionModel.GPT35Turbo,
messages = Seq.empty
)

// when
val responseEffect = client
.createStreamedChatCompletion(givenRequest)
.send(zioBackendStub)
.map(_.body.value)
.flatMap(_.runCollect)

val response = unsafeRun(responseEffect)

// then
val expectedResponse = chatChunks.map(read[ChatChunkResponse](_))
response.toList shouldBe expectedResponse
}

private def unsafeRun[E, A](zio: ZIO[Any, E, A]): A =
Unsafe.unsafe(implicit unsafe => runtime.unsafe.run(zio).getOrThrowFiberFailure())
}

0 comments on commit 18a6f2f

Please sign in to comment.