Skip to content

Commit

Permalink
add tests to azure openai encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
acatav committed Jan 10, 2024
1 parent 5507b80 commit f57e8c8
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 105 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
CO_API_KEY: ${{ secrets.CO_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }}
EMBEDDINGS_AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.EMBEDDINGS_AZURE_OPENAI_DEPLOYMENT_NAME }}

run: poetry run pytest tests/system
2 changes: 1 addition & 1 deletion pinecone_text/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .cohere_encoder import CohereEncoder
from .openai_encoder import OpenAIEncoder
from .openai_encoder import OpenAIEncoder, AzureOpenAIEncoder
from .sentence_transformer_encoder import SentenceTransformerEncoder
133 changes: 68 additions & 65 deletions pinecone_text/dense/openai_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,43 @@ class OpenAIEncoder(BaseDenseEncoder):
"""
OpenAI's text embedding wrapper. See https://platform.openai.com/docs/guides/embeddings
Note: You should provide an API key and organization in the environment variables OPENAI_API_KEY and OPENAI_ORG.
Or you can pass them as arguments to the constructor as `api_key` and `organization`.
"""
Note: this method reflects the OpenAI client initialization behaviour (See https://github.com/openai/openai-python/blob/main/src/openai/_client.py)
On initialization, you may explicitly pass any argument that the OpenAI client accepts, or use the following environment variables:
- `OPENAI_API_KEY` as `api_key`
- `OPENAI_ORG_ID` as `organization`
- `OPENAI_BASE_URL` as `base_url`
Example:
Using environment variables:
>>> import os
>>> from pinecone_text.dense import OpenAIEncoder
>>> os.environ['OPENAI_API_KEY'] = "sk-..."
>>> encoder = OpenAIEncoder()
>>> encoder.encode_documents(["some text", "some other text"])
Passing arguments explicitly:
>>> from pinecone_text.dense import OpenAIEncoder
>>> encoder = OpenAIEncoder(api_key="sk-...")
""" # noqa: E501

def __init__(
self,
model_name: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
organization: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs: Any,
):
"""
Initialize the OpenAI encoder.
:param model_name: The name of the embedding model to use. See https://beta.openai.com/docs/api-reference/embeddings
:param kwargs: Additional arguments to pass to the underlying openai client. See https://github.com/openai/openai-python
"""
if not _openai_installed:
raise ImportError(
"Failed to import openai. Make sure you install openai extra "
"dependencies by running: "
"`pip install pinecone-text[openai]"
)
self._model_name = model_name
self._client = openai.OpenAI(
api_key=api_key, organization=organization, base_url=base_url, **kwargs
)
self._client = self._create_client(**kwargs)

@staticmethod
def _create_client(**kwargs: Any) -> Union[openai.OpenAI, openai.AzureOpenAI]:
    # Factory hook for the underlying OpenAI client; subclasses (e.g. the
    # Azure variant) override this to build a different client type while
    # reusing the shared __init__/encode logic.
    return openai.OpenAI(**kwargs)

def encode_documents(
self, texts: Union[str, List[str]]
Expand All @@ -66,58 +75,52 @@ def _encode(
f"texts must be a string or list of strings, got: {type(texts)}"
)

batch_size = 16 # Azure OpenAI limit as of 2023-11-27
result = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
try:
response = self._client.embeddings.create(
input=batch, model=self._model_name
)
except OpenAIError as e:
# TODO: consider wrapping external provider errors
raise e

if isinstance(batch, str):
result.extend(response.data[0].embedding)
result.extend([result.embedding for result in response.data])
try:
response = self._client.embeddings.create(
input=texts_input, model=self._model_name
)
except OpenAIError as e:
# TODO: consider wrapping external provider errors
raise e

return result
if isinstance(texts, str):
return response.data[0].embedding
return [result.embedding for result in response.data]


class AzureOpenAIEncoder(OpenAIEncoder):
"""
Azure OpenAI's text embedding wrapper.
See https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/understand-embeddings
Note: You should provide an API key in the environment variables AZURE_OPENAI_API_KEY.
Or you can pass it as an argument to the constructor as `api_key`.
"""

def __init__(
self,
model_name: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
**kwargs: Any,
):
"""
Initialize the OpenAI encoder.
:param model_name: The name of the embedding model to use. See https://beta.openai.com/docs/api-reference/embeddings
:param kwargs: Additional arguments to pass to the underlying openai client. See https://github.com/openai/openai-python
"""
if not _openai_installed:
raise ImportError(
"Failed to import openai. Make sure you install openai extra "
"dependencies by running: "
"`pip install pinecone-text[openai]"
)
self._model_name = model_name
self._client = openai.AzureOpenAI(
api_key=api_key,
api_version="2023-05-15",
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
base_url=base_url,
**kwargs,
)
Initialize the Azure OpenAI encoder.
Note: this method reflects the AzureOpenAI client initialization behaviour (See https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py).
You may explicitly pass any argument that the AzureOpenAI client accepts, or use the following environment variables:
- `AZURE_OPENAI_API_KEY` as `api_key`
- `AZURE_OPENAI_ENDPOINT` as `azure_endpoint`
- `OPENAI_API_VERSION` as `api_version`
- `OPENAI_ORG_ID` as `organization`
- `AZURE_OPENAI_AD_TOKEN` as `azure_ad_token`
In addition, you must pass the `model_name` argument with the name of the deployment you wish to use in your own Azure account.
Example:
Using environment variables:
>>> import os
>>> from pinecone_text.dense import AzureOpenAIEncoder
>>> os.environ['AZURE_OPENAI_API_KEY'] = "sk-..."
>>> os.environ['AZURE_OPENAI_ENDPOINT'] = "https://.....openai.azure.com/"
>>> os.environ['OPENAI_API_VERSION'] = "2023-12-01-preview"
>>> encoder = AzureOpenAIEncoder(model_name="my-ada-002-deployment")
>>> encoder.encode_documents(["some text", "some other text"])
Passing arguments explicitly:
>>> from pinecone_text.dense import AzureOpenAIEncoder
>>> encoder = AzureOpenAIEncoder(api_key="sk-...", azure_endpoint="https://.....openai.azure.com/", api_version="2023-12-01-preview")
""" # noqa: E501

def __init__(self, model_name: str, **kwargs: Any):
    # model_name is required here (no default): for Azure it is the name of
    # the deployment in the caller's Azure account, not a generic model id.
    # All other client arguments pass through to the base OpenAIEncoder.
    super().__init__(model_name=model_name, **kwargs)

@staticmethod
def _create_client(**kwargs: Any) -> openai.AzureOpenAI:
    # Override of the base factory hook: build an AzureOpenAI client, which
    # reads AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT / OPENAI_API_VERSION
    # from the environment when not passed explicitly.
    return openai.AzureOpenAI(**kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pinecone-text"
version = "0.7.1"
version = "0.7.2"
description = "Text utilities library by Pinecone.io"
authors = ["Pinecone.io"]
readme = "README.md"
Expand Down
23 changes: 16 additions & 7 deletions tests/system/test_openai_encoder.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import pytest
from pinecone_text.dense import OpenAIEncoder
import os
from pinecone_text.dense import OpenAIEncoder, AzureOpenAIEncoder
from openai import BadRequestError, AuthenticationError


DEFAULT_DIMENSION = 1536


@pytest.fixture
def openai_encoder():
return OpenAIEncoder()
@pytest.fixture(params=[OpenAIEncoder, AzureOpenAIEncoder])
def openai_encoder(request):
    # Parametrized fixture: every system test runs once against the plain
    # OpenAI encoder and once against the Azure OpenAI encoder.
    if request.param == OpenAIEncoder:
        # OpenAIEncoder reads OPENAI_API_KEY from the environment on its own.
        return request.param()
    else:
        # Azure requires the deployment name as model_name; it is supplied by
        # the CI secret EMBEDDINGS_AZURE_OPENAI_DEPLOYMENT_NAME — TODO confirm
        # the test is skipped/failed clearly when the variable is unset (get()
        # returns None here).
        model_name = os.environ.get("EMBEDDINGS_AZURE_OPENAI_DEPLOYMENT_NAME")
        return request.param(model_name=model_name)


def test_init_with_kwargs():
encoder = OpenAIEncoder(
api_key="test_api_key", organization="test_organization", timeout=30
@pytest.mark.parametrize("encoder_class", [OpenAIEncoder, AzureOpenAIEncoder])
def test_init_with_kwargs(encoder_class):
encoder = encoder_class(
api_key="test_api_key",
organization="test_organization",
timeout=30,
model_name="test_model_name",
)
assert encoder._client.api_key == "test_api_key"
assert encoder._client.organization == "test_organization"
Expand Down
67 changes: 36 additions & 31 deletions tests/unit/test_openai_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import patch, Mock
from pinecone_text.dense import OpenAIEncoder
from pinecone_text.dense import OpenAIEncoder, AzureOpenAIEncoder


def create_mock_response(embeddings):
Expand All @@ -13,10 +13,11 @@ def create_mock_response(embeddings):
mock_multiple_embeddings = create_mock_response([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])


@pytest.fixture
def openai_encoder():
@pytest.fixture(params=[OpenAIEncoder, AzureOpenAIEncoder])
def encoder(request):
encoder_class = request.param
with patch("pinecone_text.dense.openai_encoder.openai"):
yield OpenAIEncoder()
return encoder_class(model_name="test_model_name")


def test_init_without_openai_installed():
Expand All @@ -25,21 +26,31 @@ def test_init_without_openai_installed():
OpenAIEncoder()


def test_init_with_kwargs():
@pytest.mark.parametrize("encoder_class", [OpenAIEncoder, AzureOpenAIEncoder])
def test_init_with_kwargs(encoder_class):
with patch("pinecone_text.dense.openai_encoder.openai") as mock_openai:
OpenAIEncoder(
api_key="test_api_key", organization="test_organization", timeout=30
)
mock_openai.OpenAI.assert_called_with(
encoder_class(
api_key="test_api_key",
organization="test_organization",
base_url=None,
timeout=30,
model_name="test_model_name",
)


def encode_by_type(openai_encoder, encoding_function, test_input):
func = getattr(openai_encoder, encoding_function)
if encoder_class == OpenAIEncoder:
mock_openai.OpenAI.assert_called_with(
api_key="test_api_key",
organization="test_organization",
timeout=30,
)
else:
mock_openai.AzureOpenAI.assert_called_with(
api_key="test_api_key",
organization="test_organization",
timeout=30,
)


def encode_by_type(encoder, encoding_function, test_input):
    """Resolve *encoding_function* by name on *encoder* and invoke it with *test_input*."""
    return getattr(encoder, encoding_function)(test_input)


Expand All @@ -50,12 +61,10 @@ def encode_by_type(openai_encoder, encoding_function, test_input):
("encode_queries"),
],
)
def test_encode_single_text(openai_encoder, encoding_function):
with patch.object(
openai_encoder._client, "embeddings", create=True
) as mock_embeddings:
def test_encode_single_text(encoder, encoding_function):
with patch.object(encoder._client, "embeddings", create=True) as mock_embeddings:
mock_embeddings.create.return_value = mock_single_embedding
result = encode_by_type(openai_encoder, encoding_function, "test text")
result = encode_by_type(encoder, encoding_function, "test text")
assert result == [0.1, 0.2, 0.3]


Expand All @@ -66,12 +75,10 @@ def test_encode_single_text(openai_encoder, encoding_function):
("encode_queries"),
],
)
def test_encode_multiple_texts(openai_encoder, encoding_function):
with patch.object(
openai_encoder._client, "embeddings", create=True
) as mock_embeddings:
def test_encode_multiple_texts(encoder, encoding_function):
with patch.object(encoder._client, "embeddings", create=True) as mock_embeddings:
mock_embeddings.create.return_value = mock_multiple_embeddings
result = encode_by_type(openai_encoder, encoding_function, ["text1", "text2"])
result = encode_by_type(encoder, encoding_function, ["text1", "text2"])
assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]


Expand All @@ -82,9 +89,9 @@ def test_encode_multiple_texts(openai_encoder, encoding_function):
("encode_queries"),
],
)
def test_encode_invalid_input(openai_encoder, encoding_function):
def test_encode_invalid_input(encoder, encoding_function):
with pytest.raises(ValueError):
encode_by_type(openai_encoder, encoding_function, 123)
encode_by_type(encoder, encoding_function, 123)


@pytest.mark.parametrize(
Expand All @@ -94,10 +101,8 @@ def test_encode_invalid_input(openai_encoder, encoding_function):
("encode_queries"),
],
)
def test_encode_error_handling(openai_encoder, encoding_function):
with patch.object(
openai_encoder._client, "embeddings", create=True
) as mock_embeddings:
def test_encode_error_handling(encoder, encoding_function):
with patch.object(encoder._client, "embeddings", create=True) as mock_embeddings:
mock_embeddings.create.side_effect = ValueError("OpenAI API error")
with pytest.raises(ValueError, match="OpenAI API error"):
encode_by_type(openai_encoder, encoding_function, "test text")
encode_by_type(encoder, encoding_function, "test text")

0 comments on commit f57e8c8

Please sign in to comment.