Add batching from OpenAI #515

Open · wants to merge 3 commits into main
2 changes: 2 additions & 0 deletions mirascope/core/openai/__init__.py
@@ -6,6 +6,7 @@

from ._call import openai_call
from ._call import openai_call as call
from .batch import OpenAIBatch
from .call_params import OpenAICallParams
from .call_response import OpenAICallResponse
from .call_response_chunk import OpenAICallResponseChunk
@@ -26,4 +27,5 @@
"OpenAITool",
"OpenAIToolConfig",
"openai_call",
"OpenAIBatch",
]
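With this change, `OpenAIBatch` can be imported either from `mirascope.core.openai` directly or through the `openai` provider namespace. A minimal sanity check, assuming the patched package is installed:

```python
from mirascope.core import openai
from mirascope.core.openai import OpenAIBatch

# Both import paths should resolve to the same class added in this PR.
assert openai.OpenAIBatch is OpenAIBatch
```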
110 changes: 110 additions & 0 deletions mirascope/core/openai/batch.py
@@ -0,0 +1,110 @@
import json
import os
import uuid
from collections.abc import Callable

from openai import OpenAI
from openai.types import Batch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..base import BaseMessageParam
from ._utils import convert_message_params


class RequestBody(BaseModel):
    model: str
    messages: list[ChatCompletionMessageParam]


class Request(BaseModel):
    custom_id: str
    method: str
    url: str
    body: RequestBody


class OpenAIBatch:
    """
    Example:

    ```python
    import time

    from mirascope.core import openai, prompt_template

    @prompt_template("Recommend a {genre} book")
    def recommend_book(genre: str): ...

    batch = openai.OpenAIBatch()
    batch.add(recommend_book, "gpt-4o-mini", data=["fantasy", "horror"])
    batch.run()

    while batch.is_in_progress():
        print("waiting 5 seconds")
        time.sleep(5)
    print(batch.retrieve())
    ```
    """

    job: Batch

    def __init__(self, filename: str = "./batch_job.jsonl", **kwargs: str) -> None:
        self.filename = filename
        self.client = OpenAI(
            api_key=str(kwargs.get("OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]))
        )

    def add(
        self,
        func: Callable[..., list[BaseMessageParam]],
        model: str,
        data: list[str],
    ) -> None:
        tasks = []
        for d in data:
            request = Request(
                custom_id=f"task-{uuid.uuid4()}",
                method="POST",
                url="/v1/chat/completions",
                body=RequestBody(
                    model=model,
                    messages=convert_message_params(func(d)),  # type: ignore
                ),
            )

            tasks.append(request.model_dump())

        with open(self.filename, "a") as f:
            for task in tasks:
                f.write(json.dumps(task) + "\n")

    def run(self) -> None:
        with open(self.filename, "rb") as f:
            batch_file = self.client.files.create(file=f, purpose="batch")

        self.job = self.client.batches.create(
            input_file_id=batch_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def is_in_progress(self) -> bool:
        self.job = self.client.batches.retrieve(self.job.id)
        # "expired" and "cancelled" are also terminal statuses in the Batch API.
        return self.job.status not in ["failed", "completed", "expired", "cancelled"]

    def retrieve(self, result_file_name: str = "results.jsonl") -> list | None:
        self.job = self.client.batches.retrieve(self.job.id)

        if self.job.status == "completed":
            result_file_id = self.job.output_file_id
            if result_file_id:
                result = self.client.files.content(result_file_id).content

                with open(result_file_name, "wb") as file:
                    file.write(result)

                results = []
                with open(result_file_name) as file:
                    for line in file:
                        json_object = json.loads(line.strip())
                        results.append(json_object)
                return results
        return None
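Each `add` call appends one JSON object per data item to the batch file, in the request format the batches endpoint expects. A sketch of a single serialized line, matching the shape the tests below assert (the `custom_id` UUID is illustrative):

```python
import json

# One JSONL line as written by OpenAIBatch.add; the UUID is illustrative.
request_line = {
    "custom_id": "task-87654321-4321-8765-4321-876543218765",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-4o-mini",
        "messages": [{"content": "Recommend a fantasy book", "role": "user"}],
    },
}
print(json.dumps(request_line))
```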
165 changes: 165 additions & 0 deletions tests/core/openai/test_batch.py
@@ -0,0 +1,165 @@
"""Tests the `openai.add_batch` module."""

import json
import os
from unittest import TestCase
from unittest.mock import MagicMock, mock_open, patch
from uuid import UUID, uuid4

from mirascope.core import prompt_template
from mirascope.core.openai import OpenAIBatch


@prompt_template("Recommend a {genre} book")
def recommend_book(genre: str): ...


class OpenAIBatchTests(TestCase):
def test_openai_call_add_batch_creates(self) -> None:
data_genres = ["fantasy", "horror"]
expected_uuid = UUID("87654321-4321-8765-4321-876543218765")
batch_filename = f"/tmp/batch_job_{uuid4()}.jsonl"
expected = [
{
"custom_id": f"task-{expected_uuid}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-mini",
"messages": [
{"content": "Recommend a fantasy book", "role": "user"}
],
},
},
{
"custom_id": f"task-{expected_uuid}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-mini",
"messages": [
{"content": "Recommend a horror book", "role": "user"}
],
},
},
]

if os.path.exists(batch_filename):
os.remove(batch_filename)

with patch("uuid.uuid4") as mock_uuid4:
mock_uuid4.return_value = expected_uuid
batch = OpenAIBatch(batch_filename)
batch.add(recommend_book, "gpt-4o-mini", data_genres)

with open(batch_filename) as f:
assert (
json.dumps(expected[0]) + "\n" + json.dumps(expected[1]) + "\n"
== f.read()
)

if os.path.exists(batch_filename):
os.remove(batch_filename)

def test_openai_call_add_batch_appends(self) -> None:
data_genres = ["fantasy"]
expected_uuid = UUID("87654321-4321-8765-4321-876543218765")
batch_filename = f"/tmp/batch_job_{uuid4()}.jsonl"
expected = {
"custom_id": f"task-{expected_uuid}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-mini",
"messages": [{"content": "Recommend a fantasy book", "role": "user"}],
},
}

if os.path.exists(batch_filename):
os.remove(batch_filename)

with open(batch_filename, "w") as f:
f.write(json.dumps(expected) + "\n")

with patch("uuid.uuid4") as mock_uuid4:
mock_uuid4.return_value = expected_uuid
batch = OpenAIBatch(batch_filename)
batch.add(recommend_book, "gpt-4o-mini", data_genres)

with open(batch_filename) as f:
assert json.dumps(expected) + "\n" + json.dumps(expected) + "\n" == f.read()

if os.path.exists(batch_filename):
os.remove(batch_filename)

@patch("builtins.open", new_callable=mock_open, read_data="data")
@patch("mirascope.core.openai.batch.OpenAI", new_callable=MagicMock())
def test_openai_run_batch(self, mock_openai: MagicMock, mock_file) -> None:
mock_create = mock_openai.return_value.files.create
mock_batch_create = mock_openai.return_value.batches.create
mock_create.return_value.id = "batch_id"

batch_filename = f"/tmp/batch_job_{uuid4()}.jsonl"
batch = [
{
"custom_id": "task-0",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-mini",
"messages": [
{"content": "Recommend a fantasy book", "role": "user"}
],
},
},
{
"custom_id": "task-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-mini",
"messages": [
{"content": "Recommend a horror book", "role": "user"}
],
},
},
]

if os.path.exists(batch_filename):
os.remove(batch_filename)

with open(batch_filename, "w") as f:
for b in batch:
f.write(json.dumps(b) + "\n")

batch = OpenAIBatch(batch_filename)
batch.run()

mock_create.assert_called_once_with(file=mock_file(), purpose="batch")
mock_batch_create.assert_called_once_with(
input_file_id="batch_id",
endpoint="/v1/chat/completions",
completion_window="24h",
)
if os.path.exists(batch_filename):
os.remove(batch_filename)

@patch("builtins.open", new_callable=mock_open, read_data='{"some": "result"}')
@patch("mirascope.core.openai.batch.OpenAI", new_callable=MagicMock())
def test_openai_retrieve_batch(self, mock_openai: MagicMock, mock_file) -> None:
mock_content = mock_openai.return_value.files.content
mock_batch_create = mock_openai.return_value.batches.create
mock_batch_create.return_value.id = "batch_id"
mock_retrieve = mock_openai.return_value.batches.retrieve
mock_content.return_value.content = "content"
mock_retrieve.return_value.status = "completed"
mock_retrieve.return_value.output_file_id = "some-output-file-id"

batch = OpenAIBatch()
batch.run()
results = batch.retrieve(result_file_name="test.jsonl")

mock_retrieve.assert_called_once_with("batch_id")
mock_content.assert_called_once_with("some-output-file-id")
mock_file.assert_called()
assert results == [{"some": "result"}]
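In real use, `retrieve` returns the parsed output lines, each of which wraps a chat completion in OpenAI's batch output envelope. A hedged sketch of unpacking the assistant text, assuming the standard Batch API output shape:

```python
# Each result line is expected to look like:
# {"custom_id": ..., "response": {"status_code": 200, "body": {...}}, "error": null}
results = batch.retrieve() or []
for result in results:
    body = result["response"]["body"]
    print(result["custom_id"], body["choices"][0]["message"]["content"])
```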