Skip to content

Commit

Permalink
Add streaming support for chat completions (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
DybekK authored Nov 7, 2023
1 parent a0ad113 commit 91e7617
Show file tree
Hide file tree
Showing 18 changed files with 657 additions and 77 deletions.
106 changes: 95 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,15 @@ sttp openai is available for Scala 2.13 and Scala 3

OpenAI API Official Documentation https://platform.openai.com/docs/api-reference/completions

### Not yet implemented:
* Create chat completions SSE
* Create completions SSE
* List fine-tune events SSE

## Example

### To use ChatGPT

```scala mdoc:compile-only
import sttp.openai.OpenAISyncClient
import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatResponse
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
import sttp.openai.requests.completions.chat.{Message, Role}

import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel, Message}
import sttp.openai.requests.completions.chat.Role

object Main extends App {
// Create an instance of OpenAISyncClient providing your API secret-key
Expand Down Expand Up @@ -92,11 +86,12 @@ or use backend of your choice.
```scala mdoc:compile-only
import cats.effect.{ExitCode, IO, IOApp}
import sttp.client4.httpclient.cats.HttpClientCatsBackend

import sttp.openai.OpenAI
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai._
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel}
import sttp.openai.requests.completions.chat.ChatRequestResponseData.ChatResponse
import sttp.openai.requests.completions.chat.{Message, Role}
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel, Message}
import sttp.openai.requests.completions.chat.Role

object Main extends IOApp {
override def run(args: List[String]): IO[ExitCode] = {
Expand Down Expand Up @@ -143,6 +138,95 @@ object Main extends IOApp {
*/
}
```

#### Create chat completion with streaming:
The Chat Completion API features streaming support via server-sent events. Currently, we only support streaming using `Fs2`.

Add the following import:

```scala
import sttp.openai.streaming.fs2._
```

The example below uses `HttpClientFs2Backend` as the backend.

```scala mdoc:compile-only
import cats.effect.{ExitCode, IO, IOApp}
import fs2.Stream
import sttp.client4.httpclient.fs2.HttpClientFs2Backend

import sttp.openai.OpenAI
import sttp.openai.streaming.fs2._
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.requests.completions.chat.ChatChunkRequestResponseData.ChatChunkResponse
import sttp.openai.requests.completions.chat.ChatRequestBody.{ChatBody, ChatCompletionModel, Message}
import sttp.openai.requests.completions.chat.Role

object Main extends IOApp {
override def run(args: List[String]): IO[ExitCode] = {
val openAI: OpenAI = new OpenAI("your-secret-key")

val bodyMessages: Seq[Message] = Seq(
Message(
role = Role.User,
content = "Hello!"
)
)

val chatRequestBody: ChatBody = ChatBody(
model = ChatCompletionModel.GPT35Turbo,
messages = bodyMessages
)

HttpClientFs2Backend.resource[IO]().use { backend =>
val response: IO[Either[OpenAIException, Stream[IO, ChatChunkResponse]]] =
openAI
.createStreamedChatCompletion[IO](chatRequestBody)
.send(backend)
.map(_.body)

response
.flatMap {
case Left(exception) => IO.println(exception.getMessage)
case Right(stream) => stream.evalTap(IO.println).compile.drain
}
.as(ExitCode.Success)
}
}
/*
...
ChatChunkResponse(
"chatcmpl-8HEZFNDmu2AYW8jVvNKyRO4W4KcO8",
"chat.completion.chunk",
1699118265,
"gpt-3.5-turbo-0613",
List(
Choices(
Delta(None, Some("Hi"), None),
null,
0
)
)
)
...
ChatChunkResponse(
"chatcmpl-8HEZFNDmu2AYW8jVvNKyRO4W4KcO8",
"chat.completion.chunk",
1699118265,
"gpt-3.5-turbo-0613",
List(
Choices(
Delta(None, Some(" there"), None),
null,
0
)
)
)
...
*/
}
```

## Contributing

If you have a question, or hit a problem, feel free to post on our community https://softwaremill.community/c/open-source/
Expand Down
17 changes: 16 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
lazy val root = (project in file("."))
.settings(commonSettings: _*)
.settings(publish / skip := true, name := "sttp-openai", scalaVersion := scala2.head)
.aggregate(core.projectRefs: _*)
.aggregate(allAgregates: _*)

lazy val allAgregates = core.projectRefs ++
fs2.projectRefs ++
docs.projectRefs

lazy val core = (projectMatrix in file("core"))
.jvmPlatform(
Expand All @@ -26,6 +30,16 @@ lazy val core = (projectMatrix in file("core"))
)
.settings(commonSettings: _*)

lazy val fs2 = (projectMatrix in file("streaming/fs2"))
.jvmPlatform(
scalaVersions = scala2 ++ scala3
)
.settings(commonSettings)
.settings(
libraryDependencies ++= Libraries.sttpClientFs2
)
.dependsOn(core % "compile->compile;test->test")

val compileDocs: TaskKey[Unit] = taskKey[Unit]("Compiles docs module throwing away its output")
compileDocs := {
(docs.jvm(scala2.head) / mdoc).toTask(" --out target/sttp-openai-docs").value
Expand All @@ -48,4 +62,5 @@ lazy val docs = (projectMatrix in file("generated-docs")) // important: it must
evictionErrorLevel := Level.Info
)
.dependsOn(core)
.dependsOn(fs2)
.jvmPlatform(scalaVersions = scala2)
18 changes: 17 additions & 1 deletion core/src/main/scala/sttp/openai/OpenAI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.openai
import sttp.client4._
import sttp.model.Uri
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.json.SttpUpickleApiExtension.{asJsonSnake, asStringEither, upickleBodySerializer}
import sttp.openai.json.SttpUpickleApiExtension.{asJsonSnake, asStreamSnake, asStringEither, upickleBodySerializer}
import sttp.openai.requests.completions.CompletionsRequestBody.CompletionsBody
import sttp.openai.requests.completions.CompletionsResponseData.CompletionsResponse
import sttp.openai.requests.completions.chat.ChatRequestBody.ChatBody
Expand Down Expand Up @@ -31,6 +31,7 @@ import sttp.openai.requests.audio.AudioResponseData.AudioResponse
import sttp.openai.requests.audio.transcriptions.TranscriptionConfig
import sttp.openai.requests.audio.translations.TranslationConfig
import sttp.openai.requests.audio.RecognitionModel
import sttp.capabilities.Streams

import java.io.File
import java.nio.file.Paths
Expand Down Expand Up @@ -239,6 +240,21 @@ class OpenAI(authToken: String) {
.body(chatBody)
.response(asJsonSnake[ChatResponse])

/** Creates and streams a model response as chunk objects for the given chat conversation defined in chatBody.
*
* [[https://platform.openai.com/docs/api-reference/chat/create]]
*
* @param s
* The streams implementation to use.
* @param chatBody
* Chat request body.
*/
def createChatCompletion[S](s: Streams[S], chatBody: ChatBody): StreamRequest[Either[OpenAIException, s.BinaryStream], S] =
openAIAuthRequest
.post(OpenAIUris.ChatCompletions)
.body(ChatBody.withStreaming(chatBody))
.response(asStreamSnake(s))

/** Returns a list of files that belong to the user's organization.
*
* [[https://platform.openai.com/docs/api-reference/files]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ import sttp.model.StatusCode._
import sttp.model.ResponseMetadata
import sttp.openai.OpenAIExceptions.OpenAIException
import sttp.openai.OpenAIExceptions.OpenAIException._
import sttp.capabilities.Streams

/** An sttp upickle api extension that deserializes JSON with snake_case keys into case classes with fields corresponding to keys in
* camelCase and maps errors to OpenAIException subclasses.
*/
object SttpUpickleApiExtension extends SttpUpickleApi {
override val upickleApi: SnakePickle.type = SnakePickle

def asStreamSnake[S](s: Streams[S]): StreamResponseAs[Either[OpenAIException, s.BinaryStream], S] =
asStreamUnsafe(s).mapWithMetadata { (body, meta) =>
body.left.map(errorBody => httpToOpenAIError(HttpError(errorBody, meta.code)))
}

def asJsonSnake[B: upickleApi.Reader: IsOption]: ResponseAs[Either[OpenAIException, B]] =
asString.mapWithMetadata(deserializeRightWithMappedExceptions(deserializeJsonSnake)).showAsJson

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package sttp.openai.requests.completions.chat

import sttp.openai.json.SnakePickle

object ChatChunkRequestResponseData {

/** @param role
* The role of the author of this message.
* @param content
* The contents of the message.
* @param functionCall
* The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
*/
case class Delta(role: Option[Role] = None, content: Option[String] = None, functionCall: Option[FunctionCall] = None)

object Delta {
implicit val deltaR: SnakePickle.Reader[Delta] = SnakePickle.macroR[Delta]
}

case class Choices(
delta: Delta,
finishReason: String,
index: Int
)

object Choices {
implicit val choicesR: SnakePickle.Reader[Choices] = SnakePickle.macroR[Choices]
}

case class ChatChunkResponse(
id: String,
`object`: String,
created: Int,
model: String,
choices: Seq[Choices]
)

object ChatChunkResponse {
val DoneEventMessage = "[DONE]"

implicit val chunkChatR: SnakePickle.Reader[ChatChunkResponse] = SnakePickle.macroR[ChatChunkResponse]
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@ import ujson.Str

object ChatRequestBody {

/** @param role
* The role of the author of this message.
* @param content
* The contents of the message.
* @param name
* The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
* @param functionCall
* The name and arguments of a function that should be called, as generated by the model.
*/
case class Message(role: Role, content: String, name: Option[String] = None, functionCall: Option[FunctionCall] = None)

object Message {
implicit val messageRW: SnakePickle.ReadWriter[Message] = SnakePickle.macroRW[Message]
}

/** @param model
* ID of the model to use.
* @param messages
Expand Down Expand Up @@ -49,6 +64,12 @@ object ChatRequestBody {
)

object ChatBody {
def withStreaming(chatBody: ChatBody): ujson.Value = {
val json = SnakePickle.writeJs(chatBody)
json.obj("stream") = true
json
}

implicit val chatRequestW: SnakePickle.Writer[ChatBody] = SnakePickle.macroW[ChatBody]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@ import sttp.openai.json.SnakePickle
import sttp.openai.requests.completions.Usage

object ChatRequestResponseData {

/** @param role
* The role of the author of this message.
* @param content
* The contents of the message.
* @param functionCall
* The name and arguments of a function that should be called, as generated by the model.
*/
case class Message(role: Role, content: String, functionCall: Option[FunctionCall] = None)

object Message {
implicit val messageRW: SnakePickle.Reader[Message] = SnakePickle.macroR[Message]
}

case class Choices(
message: Message,
finishReason: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sttp.openai.requests.completions.chat

import sttp.openai.json.SnakePickle

/** @param arguments
* The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid
* JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your
* function.
* @param name
* The name of the function to call.
*/
case class FunctionCall(arguments: String, name: String)

object FunctionCall {
implicit val functionCallRW: SnakePickle.ReadWriter[FunctionCall] = SnakePickle.macroRW[FunctionCall]
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,8 @@ import sttp.openai.OpenAIExceptions.OpenAIException.DeserializationOpenAIExcepti
import sttp.openai.json.SnakePickle
import ujson.Str

/** @param role
* The role of the author of this message. One of [[Role]].
* @param content
* The contents of the message.
* @param name
* The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
*/
case class Message(role: Role, content: String, name: Option[String] = None)

object Message {
implicit val messageRW: SnakePickle.ReadWriter[Message] = SnakePickle.macroRW[Message]
}

sealed abstract class Role(val value: String)

object Role {
case object System extends Role("system")

Expand Down
Loading

0 comments on commit 91e7617

Please sign in to comment.