From b7849cc3abc328203333052f25a722cb668d116c Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Sun, 5 Jan 2025 14:19:04 -0800 Subject: [PATCH] Unifying GenAI and VertexAI based on the shared Gemini REST API. By extracting 'gemini.py' to deal with Gemini's REST API, we are able to share the most parts of Google GenAI and VertexAI. This CL establishes us to support new Gemini models/features with a single copy of changes. Also we now have all our LLMs implemented in REST, minimizing dependencies needed when using Langfun. PiperOrigin-RevId: 712314155 --- README.md | 8 +- langfun/core/llms/__init__.py | 47 ++- langfun/core/llms/gemini.py | 500 +++++++++++++++++++++++++ langfun/core/llms/gemini_test.py | 190 ++++++++++ langfun/core/llms/google_genai.py | 365 +++--------------- langfun/core/llms/google_genai_test.py | 213 +---------- langfun/core/llms/vertexai.py | 382 ++----------------- langfun/core/llms/vertexai_test.py | 172 +-------- requirements.txt | 4 +- 9 files changed, 801 insertions(+), 1080 deletions(-) create mode 100644 langfun/core/llms/gemini.py create mode 100644 langfun/core/llms/gemini_test.py diff --git a/README.md b/README.md index aaa04ead..553dce36 100644 --- a/README.md +++ b/README.md @@ -138,9 +138,7 @@ If you want to customize your installation, you can select specific features usi | Tag | Description | | ------------------- | ---------------------------------------- | | all | All Langfun features. | -| llm | All supported LLMs. | -| llm-google | All supported Google-powered LLMs. | -| llm-google-genai | LLMs powered by Google Generative AI API | +| vertexai | VertexAI access. | | mime | All MIME supports. | | mime-auto | Automatic MIME type detection. | | mime-docx | DocX format support. | @@ -149,9 +147,9 @@ If you want to customize your installation, you can select specific features usi | ui | UI enhancements | -For example, to install a nightly build that includes Google-powered LLMs, full modality support, and UI enhancements, use: +For example, to install a nightly build that includes VertexAI access, full modality support, and UI enhancements, use: ``` -pip install langfun[llm-google,mime,ui] --pre +pip install langfun[vertexai,mime,ui] --pre ``` *Disclaimer: this is not an officially supported Google product.* diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py index 5e390289..84923dd8 100644 --- a/langfun/core/llms/__init__.py +++ b/langfun/core/llms/__init__.py @@ -32,16 +32,30 @@ # Gemini models. from langfun.core.llms.google_genai import GenAI -from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp +from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219 from langfun.core.llms.google_genai import GeminiFlash2_0Exp -from langfun.core.llms.google_genai import GeminiExp_20241114 from langfun.core.llms.google_genai import GeminiExp_20241206 -from langfun.core.llms.google_genai import GeminiFlash1_5 +from langfun.core.llms.google_genai import GeminiExp_20241114 from langfun.core.llms.google_genai import GeminiPro1_5 -from langfun.core.llms.google_genai import GeminiPro -from langfun.core.llms.google_genai import GeminiProVision -from langfun.core.llms.google_genai import Palm2 -from langfun.core.llms.google_genai import Palm2_IT +from langfun.core.llms.google_genai import GeminiPro1_5_002 +from langfun.core.llms.google_genai import GeminiPro1_5_001 +from langfun.core.llms.google_genai import GeminiFlash1_5 +from langfun.core.llms.google_genai import GeminiFlash1_5_002 +from langfun.core.llms.google_genai import GeminiFlash1_5_001 +from langfun.core.llms.google_genai import GeminiPro1 + +from langfun.core.llms.vertexai import VertexAI +from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219 +from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp +from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206 +from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001 +from langfun.core.llms.vertexai import VertexAIGeminiPro1 # OpenAI models. from langfun.core.llms.openai import OpenAI @@ -124,25 +138,6 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo -from langfun.core.llms.vertexai import VertexAI -from langfun.core.llms.vertexai import VertexAIGemini2_0 -from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp -from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp -from langfun.core.llms.vertexai import VertexAIGemini1_5 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514 -from langfun.core.llms.vertexai import VertexAIGeminiPro1 -from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision -from langfun.core.llms.vertexai import VertexAIEndpoint - - # LLaMA C++ models. from langfun.core.llms.llama_cpp import LlamaCppRemote diff --git a/langfun/core/llms/gemini.py b/langfun/core/llms/gemini.py new file mode 100644 index 00000000..f5bbbdaf --- /dev/null +++ b/langfun/core/llms/gemini.py @@ -0,0 +1,500 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemini REST API (Shared by Google GenAI and Vertex AI).""" + +import base64 +from typing import Any + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import rest +import pyglove as pg + +# Supported modalities. + +IMAGE_TYPES = [ + 'image/png', + 'image/jpeg', + 'image/webp', + 'image/heic', + 'image/heif', +] + +AUDIO_TYPES = [ + 'audio/aac', + 'audio/flac', + 'audio/mp3', + 'audio/m4a', + 'audio/mpeg', + 'audio/mpga', + 'audio/mp4', + 'audio/opus', + 'audio/pcm', + 'audio/wav', + 'audio/webm', +] + +VIDEO_TYPES = [ + 'video/mov', + 'video/mpeg', + 'video/mpegps', + 'video/mpg', + 'video/mp4', + 'video/webm', + 'video/wmv', + 'video/x-flv', + 'video/3gpp', + 'video/quicktime', +] + +DOCUMENT_TYPES = [ + 'application/pdf', + 'text/plain', + 'text/csv', + 'text/html', + 'text/xml', + 'text/x-script.python', + 'application/json', +] + +TEXT_ONLY = [] + +ALL_MODALITIES = ( + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES +) + +SUPPORTED_MODELS_AND_SETTINGS = { + # For automatically rate control and cost estimation, we explicitly register + # supported models here. This may be inconvenient for new models, but it + # helps us to keep track of the models and their pricing. + # Models and RPM are from + # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB + # Pricing in US dollars, from https://ai.google.dev/pricing + # as of 2025-01-03. + # NOTE: Please update google_genai.py, vertexai.py, __init__.py when + # adding new models. + # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!! + 'gemini-2.0-flash-thinking-exp-1219': pg.Dict( + latest_update='2024-12-19', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-2.0-flash-exp': pg.Dict( + latest_update='2024-12-11', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-exp-1206': pg.Dict( + latest_update='2024-12-06', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'learnlm-1.5-pro-experimental': pg.Dict( + latest_update='2024-11-19', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-exp-1114': pg.Dict( + latest_update='2024-11-14', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-1.5-flash-latest': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-001': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-002': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-8b': pg.Dict( + latest_update='2024-10-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=4000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.0375, + cost_per_1m_output_tokens_up_to_128k=0.15, + cost_per_1m_cached_tokens_up_to_128k=0.01, + cost_per_1m_input_tokens_longer_than_128k=0.075, + cost_per_1m_output_tokens_longer_than_128k=0.3, + cost_per_1m_cached_tokens_longer_than_128k=0.02, + ), + 'gemini-1.5-flash-8b-001': pg.Dict( + latest_update='2024-10-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=4000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.0375, + cost_per_1m_output_tokens_up_to_128k=0.15, + cost_per_1m_cached_tokens_up_to_128k=0.01, + cost_per_1m_input_tokens_longer_than_128k=0.075, + cost_per_1m_output_tokens_longer_than_128k=0.3, + cost_per_1m_cached_tokens_longer_than_128k=0.02, + ), + 'gemini-1.5-pro-latest': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro-001': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro-002': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.0-pro': pg.Dict( + in_service=False, + supported_modalities=TEXT_ONLY, + rpm_free=15, + tpm_free=32_000, + rpm_paid=360, + tpm_paid=120_000, + cost_per_1m_input_tokens_up_to_128k=0.5, + cost_per_1m_output_tokens_up_to_128k=1.5, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0.5, + cost_per_1m_output_tokens_longer_than_128k=1.5, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), +} + + +@pg.use_init_args(['model']) +class Gemini(rest.REST): + """Language models provided by Google GenAI.""" + + model: pg.typing.Annotated[ + pg.typing.Enum( + pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys()) + ), + 'The name of the model to use.', + ] + + @property + def supported_modalities(self) -> list[str]: + """Returns the list of supported modalities.""" + return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities + + @property + def max_concurrency(self) -> int: + """Returns the maximum number of concurrent requests.""" + return self.rate_to_max_concurrency( + requests_per_min=max( + SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free, + SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid + ), + tokens_per_min=max( + SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free, + SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid, + ), + ) + + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + entry = SUPPORTED_MODELS_AND_SETTINGS[self.model] + if num_input_tokens < 128_000: + cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k + cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k + else: + cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k + cost_per_1m_output_tokens = ( + entry.cost_per_1m_output_tokens_longer_than_128k + ) + return ( + cost_per_1m_input_tokens * num_input_tokens + + cost_per_1m_output_tokens * num_output_tokens + ) / 1000_1000 + + @property + def model_id(self) -> str: + """Returns a string to identify the model.""" + return self.model + + @classmethod + def dir(cls): + return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service] + + @property + def headers(self): + return { + 'Content-Type': 'application/json; charset=utf-8', + } + + def request( + self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions + ) -> dict[str, Any]: + request = dict( + generationConfig=self._generation_config(prompt, sampling_options) + ) + request['contents'] = [self._content_from_message(prompt)] + return request + + def _generation_config( + self, prompt: lf.Message, options: lf.LMSamplingOptions + ) -> dict[str, Any]: + """Returns a dict as generation config for prompt and LMSamplingOptions.""" + config = dict( + temperature=options.temperature, + maxOutputTokens=options.max_tokens, + candidateCount=options.n, + topK=options.top_k, + topP=options.top_p, + stopSequences=options.stop, + seed=options.random_seed, + responseLogprobs=options.logprobs, + logprobs=options.top_logprobs, + ) + + if json_schema := prompt.metadata.get('json_schema'): + if not isinstance(json_schema, dict): + raise ValueError( + f'`json_schema` must be a dict, got {json_schema!r}.' + ) + json_schema = pg.to_json(json_schema) + config['responseSchema'] = json_schema + config['responseMimeType'] = 'application/json' + prompt.metadata.formatted_text = ( + prompt.text + + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' + + pg.to_json_str(json_schema, json_indent=2) + ) + return config + + def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]: + """Gets generation content from langfun message.""" + parts = [] + for lf_chunk in prompt.chunk(): + if isinstance(lf_chunk, str): + parts.append({'text': lf_chunk}) + elif isinstance(lf_chunk, lf_modalities.Mime): + try: + modalities = lf_chunk.make_compatible( + self.supported_modalities + ['text/plain'] + ) + if isinstance(modalities, lf_modalities.Mime): + modalities = [modalities] + for modality in modalities: + if modality.is_text: + parts.append({'text': modality.to_text()}) + else: + parts.append({ + 'inlineData': { + 'data': base64.b64encode(modality.to_bytes()).decode(), + 'mimeType': modality.mime_type, + } + }) + except lf.ModalityError as e: + raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e + else: + raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') + return dict(role='user', parts=parts) + + def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: + messages = [ + self._message_from_content_parts(candidate['content']['parts']) + for candidate in json['candidates'] + ] + usage = json['usageMetadata'] + input_tokens = usage['promptTokenCount'] + output_tokens = usage['candidatesTokenCount'] + return lf.LMSamplingResult( + [lf.LMSample(message) for message in messages], + usage=lf.LMSamplingUsage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + estimated_cost=self.estimate_cost( + num_input_tokens=input_tokens, + num_output_tokens=output_tokens, + ), + ), + ) + + def _message_from_content_parts( + self, parts: list[dict[str, Any]] + ) -> lf.Message: + """Converts Vertex AI's content parts protocol to message.""" + chunks = [] + for part in parts: + if text_part := part.get('text'): + chunks.append(text_part) + else: + raise ValueError(f'Unsupported part: {part}') + return lf.AIMessage.from_chunks(chunks) diff --git a/langfun/core/llms/gemini_test.py b/langfun/core/llms/gemini_test.py new file mode 100644 index 00000000..4f826966 --- /dev/null +++ b/langfun/core/llms/gemini_test.py @@ -0,0 +1,190 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Gemini API.""" + +import base64 +from typing import Any +import unittest +from unittest import mock + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import gemini +import pyglove as pg +import requests + + +example_image = ( + b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' + b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' + b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' + b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' + b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' + b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' + b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' + b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' +) + + +def mock_requests_post(url: str, json: dict[str, Any], **kwargs): + del url, kwargs + c = pg.Dict(json['generationConfig']) + content = json['contents'][0]['parts'][0]['text'] + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str({ + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + { + 'text': ( + f'This is a response to {content} with ' + f'temperature={c.temperature}, ' + f'top_p={c.topP}, ' + f'top_k={c.topK}, ' + f'max_tokens={c.maxOutputTokens}, ' + f'stop={"".join(c.stopSequences)}.' + ) + }, + ], + }, + }, + ], + 'usageMetadata': { + 'promptTokenCount': 3, + 'candidatesTokenCount': 4, + } + }).encode() + return response + + +class GeminiTest(unittest.TestCase): + """Tests for Vertex model with REST API.""" + + def test_content_from_message_text_only(self): + text = 'This is a beautiful day' + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + chunks = model._content_from_message(lf.UserMessage(text)) + self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]}) + + def test_content_from_message_mm(self): + image = lf_modalities.Image.from_bytes(example_image) + message = lf.UserMessage( + 'This is an <<[[image]]>>, what is it?', image=image + ) + + # Non-multimodal model. + with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): + gemini.Gemini( + 'gemini-1.0-pro', api_endpoint='' + )._content_from_message(message) + + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + content = model._content_from_message(message) + self.assertEqual( + content, + { + 'role': 'user', + 'parts': [ + {'text': 'This is an'}, + { + 'inlineData': { + 'data': base64.b64encode(example_image).decode(), + 'mimeType': 'image/png', + } + }, + {'text': ', what is it?'}, + ], + }, + ) + + def test_generation_config(self): + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + json_schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + }, + 'required': ['name'], + 'title': 'Person', + } + actual = model._generation_config( + lf.UserMessage('hi', json_schema=json_schema), + lf.LMSamplingOptions( + temperature=2.0, + top_p=1.0, + top_k=20, + max_tokens=1024, + stop=['\n'], + ), + ) + self.assertEqual( + actual, + dict( + candidateCount=1, + temperature=2.0, + topP=1.0, + topK=20, + maxOutputTokens=1024, + stopSequences=['\n'], + responseLogprobs=False, + logprobs=None, + seed=None, + responseMimeType='application/json', + responseSchema={ + 'type': 'object', + 'properties': { + 'name': {'type': 'string'} + }, + 'required': ['name'], + 'title': 'Person', + } + ), + ) + with self.assertRaisesRegex( + ValueError, '`json_schema` must be a dict, got' + ): + model._generation_config( + lf.UserMessage('hi', json_schema='not a dict'), + lf.LMSamplingOptions(), + ) + + def test_call_model(self): + with mock.patch('requests.Session.post') as mock_generate: + mock_generate.side_effect = mock_requests_post + + lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + r = lm( + 'hello', + temperature=2.0, + top_p=1.0, + top_k=20, + max_tokens=1024, + stop='\n', + ) + self.assertEqual( + r.text, + ( + 'This is a response to hello with temperature=2.0, ' + 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.' + ), + ) + self.assertEqual(r.metadata.usage.prompt_tokens, 3) + self.assertEqual(r.metadata.usage.completion_tokens, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/langfun/core/llms/google_genai.py b/langfun/core/llms/google_genai.py index 32af0eeb..4b454871 100644 --- a/langfun/core/llms/google_genai.py +++ b/langfun/core/llms/google_genai.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Langfun Authors +# Copyright 2025 The Langfun Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,57 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Gemini models exposed through Google Generative AI APIs.""" +"""Language models from Google GenAI.""" -import abc -import functools import os -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import vertexai +from langfun.core.llms import gemini import pyglove as pg -try: - import google.generativeai as genai # pylint: disable=g-import-not-at-top - BlobDict = genai.types.BlobDict - GenerativeModel = genai.GenerativeModel - Completion = getattr(genai.types, 'Completion', Any) - ChatResponse = getattr(genai.types, 'ChatResponse', Any) - GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any) - GenerationConfig = genai.GenerationConfig -except ImportError: - genai = None - BlobDict = Any - GenerativeModel = Any - Completion = Any - ChatResponse = Any - GenerationConfig = Any - GenerateContentResponse = Any - - @lf.use_init_args(['model']) -class GenAI(lf.LanguageModel): +@pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) +class GenAI(gemini.Gemini): """Language models provided by Google GenAI.""" - model: Annotated[ - Literal[ - 'gemini-2.0-flash-thinking-exp-1219', - 'gemini-2.0-flash-exp', - 'gemini-exp-1206', - 'gemini-exp-1114', - 'gemini-1.5-pro-latest', - 'gemini-1.5-flash-latest', - 'gemini-pro', - 'gemini-pro-vision', - 'text-bison-001', - 'chat-bison-001', - ], - 'Model name.', - ] - api_key: Annotated[ str | None, ( @@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel): ), ] = None - supported_modalities: Annotated[ - list[str], - 'A list of MIME types for supported modalities' - ] = [] + api_version: Annotated[ + Literal['v1beta', 'v1alpha'], + 'The API version to use.' + ] = 'v1beta' - # Set the default max concurrency to 8 workers. - max_concurrency = 8 - - def _on_bound(self): - super()._on_bound() - if genai is None: - raise RuntimeError( - 'Please install "langfun[llm-google-genai]" to use ' - 'Google Generative AI models.' - ) - self.__dict__.pop('_api_initialized', None) + @property + def model_id(self) -> str: + """Returns a string to identify the model.""" + return f'GenAI({self.model})' - @functools.cached_property - def _api_initialized(self): - assert genai is not None + @property + def api_endpoint(self) -> str: api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None) if not api_key: raise ValueError( @@ -100,306 +56,75 @@ def _api_initialized(self): 'https://cloud.google.com/api-keys/docs/create-manage-api-keys ' 'for more details.' ) - genai.configure(api_key=api_key) - return True - - @classmethod - def dir(cls) -> list[str]: - """Lists generative models.""" - assert genai is not None - return [ - m.name.lstrip('models/') - for m in genai.list_models() - if ( - 'generateContent' in m.supported_generation_methods - or 'generateText' in m.supported_generation_methods - or 'generateMessage' in m.supported_generation_methods - ) - ] - - @property - def model_id(self) -> str: - """Returns a string to identify the model.""" - return self.model - - @property - def resource_id(self) -> str: - """Returns a string to identify the resource for rate control.""" - return self.model_id - - def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]: - """Creates generation config from langfun sampling options.""" - return GenerationConfig( - candidate_count=options.n, - temperature=options.temperature, - top_p=options.top_p, - top_k=options.top_k, - max_output_tokens=options.max_tokens, - stop_sequences=options.stop, + return ( + f'https://generativelanguage.googleapis.com/{self.api_version}' + f'/models/{self.model}:generateContent?' + f'key={api_key}' ) - def _content_from_message( - self, prompt: lf.Message - ) -> list[str | BlobDict]: - """Gets Evergreen formatted content from langfun message.""" - formatted = lf.UserMessage(prompt.text) - formatted.source = prompt - - chunks = [] - for lf_chunk in formatted.chunk(): - if isinstance(lf_chunk, str): - chunks.append(lf_chunk) - elif isinstance(lf_chunk, lf_modalities.Mime): - try: - modalities = lf_chunk.make_compatible( - self.supported_modalities + ['text/plain'] - ) - if isinstance(modalities, lf_modalities.Mime): - modalities = [modalities] - for modality in modalities: - if modality.is_text: - chunk = modality.to_text() - else: - chunk = BlobDict( - data=modality.to_bytes(), - mime_type=modality.mime_type - ) - chunks.append(chunk) - except lf.ModalityError as e: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e - else: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') - return chunks - - def _response_to_result( - self, response: GenerateContentResponse | pg.Dict - ) -> lf.LMSamplingResult: - """Parses generative response into message.""" - samples = [] - for candidate in response.candidates: - chunks = [] - for part in candidate.content.parts: - # TODO(daiyip): support multi-modal parts when they are available via - # Gemini API. - if hasattr(part, 'text'): - chunks.append(part.text) - samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0)) - return lf.LMSamplingResult(samples) - - def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: - assert self._api_initialized, 'Vertex AI API is not initialized.' - return self._parallel_execute_with_currency_control( - self._sample_single, - prompts, - ) - def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult: - """Samples a single prompt.""" - model = _GOOGLE_GENAI_MODEL_HUB.get(self.model) - input_content = self._content_from_message(prompt) - response = model.generate_content( - input_content, - generation_config=self._generation_config(self.sampling_options), - ) - return self._response_to_result(response) - - -class _LegacyGenerativeModel(pg.Object): - """Base for legacy GenAI generative model.""" - - model: str - - def generate_content( - self, - input_content: list[str | BlobDict], - generation_config: GenerationConfig, - ) -> pg.Dict: - """Generate content.""" - segments = [] - for s in input_content: - if not isinstance(s, str): - raise ValueError(f'Unsupported modality: {s!r}') - segments.append(s) - return self.generate(' '.join(segments), generation_config) - - @abc.abstractmethod - def generate( - self, prompt: str, generation_config: GenerationConfig) -> pg.Dict: - """Generate response based on prompt.""" - - -class _LegacyCompletionModel(_LegacyGenerativeModel): - """Legacy GenAI completion model.""" - - def generate( - self, prompt: str, generation_config: GenerationConfig - ) -> pg.Dict: - assert genai is not None - completion: Completion = genai.generate_text( - model=f'models/{self.model}', - prompt=prompt, - temperature=generation_config.temperature, - top_k=generation_config.top_k, - top_p=generation_config.top_p, - candidate_count=generation_config.candidate_count, - max_output_tokens=generation_config.max_output_tokens, - stop_sequences=generation_config.stop_sequences, - ) - return pg.Dict( - candidates=[ - pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])])) - for c in completion.candidates - ] - ) - - -class _LegacyChatModel(_LegacyGenerativeModel): - """Legacy GenAI chat model.""" - - def generate( - self, prompt: str, generation_config: GenerationConfig - ) -> pg.Dict: - assert genai is not None - response: ChatResponse = genai.chat( - model=f'models/{self.model}', - messages=prompt, - temperature=generation_config.temperature, - top_k=generation_config.top_k, - top_p=generation_config.top_p, - candidate_count=generation_config.candidate_count, - ) - return pg.Dict( - candidates=[ - pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])])) - for c in response.candidates - ] - ) - - -class _ModelHub: - """Google Generative AI model hub.""" - - def __init__(self): - self._model_cache = {} - - def get( - self, model_name: str - ) -> GenerativeModel | _LegacyGenerativeModel: - """Gets a generative model by model id.""" - assert genai is not None - model = self._model_cache.get(model_name, None) - if model is None: - model_info = genai.get_model(f'models/{model_name}') - if 'generateContent' in model_info.supported_generation_methods: - model = genai.GenerativeModel(model_name) - elif 'generateText' in model_info.supported_generation_methods: - model = _LegacyCompletionModel(model_name) - elif 'generateMessage' in model_info.supported_generation_methods: - model = _LegacyChatModel(model_name) - else: - raise ValueError(f'Unsupported model: {model_name!r}') - self._model_cache[model_name] = model - return model - - -_GOOGLE_GENAI_MODEL_HUB = _ModelHub() - - -# -# Public Gemini models. -# -class GeminiFlash2_0ThinkingExp(GenAI): # pylint: disable=invalid-name - """Gemini 2.0 Flash Thinking Experimental model.""" +class GeminiFlash2_0ThinkingExp_20241219(GenAI): # pylint: disable=invalid-name + """Gemini Flash 2.0 Thinking model launched on 12/19/2024.""" + api_version = 'v1alpha' model = 'gemini-2.0-flash-thinking-exp-1219' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name - """Gemini Experimental model launched on 12/06/2024.""" + """Gemini Flash 2.0 model launched on 12/11/2024.""" model = 'gemini-2.0-flash-exp' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name """Gemini Experimental model launched on 12/06/2024.""" model = 'gemini-exp-1206' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name """Gemini Experimental model launched on 11/14/2024.""" model = 'gemini-exp-1114' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiPro1_5(GenAI): # pylint: disable=invalid-name """Gemini Pro latest model.""" model = 'gemini-1.5-pro-latest' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) -class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name - """Gemini Flash latest model.""" +class GeminiPro1_5_002(GenAI): # pylint: disable=invalid-name + """Gemini Pro latest model.""" + + model = 'gemini-1.5-pro-002' - model = 'gemini-1.5-flash-latest' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) +class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name + """Gemini Pro latest model.""" -class GeminiPro(GenAI): - """Gemini Pro model.""" + model = 'gemini-1.5-pro-001' - model = 'gemini-pro' + +class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name + """Gemini Flash latest model.""" + + model = 'gemini-1.5-flash-latest' -class GeminiProVision(GenAI): - """Gemini Pro vision model.""" +class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name + """Gemini Flash 1.5 model stable version 002.""" - model = 'gemini-pro-vision' - supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES + model = 'gemini-1.5-flash-002' -class Palm2(GenAI): - """PaLM2 model.""" +class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name + """Gemini Flash 1.5 model stable version 001.""" - model = 'text-bison-001' + model = 'gemini-1.5-flash-001' -class Palm2_IT(GenAI): # pylint: disable=invalid-name - """PaLM2 instruction-tuned model.""" +class GeminiPro1(GenAI): # pylint: disable=invalid-name + """Gemini 1.0 Pro model.""" - model = 'chat-bison-001' + model = 'gemini-1.0-pro' diff --git a/langfun/core/llms/google_genai_test.py b/langfun/core/llms/google_genai_test.py index 882f7a80..879e6a6c 100644 --- a/langfun/core/llms/google_genai_test.py +++ b/langfun/core/llms/google_genai_test.py @@ -11,223 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Gemini models.""" +"""Tests for Google GenAI models.""" import os import unittest -from unittest import mock - -from google import generativeai as genai -import langfun.core as lf -from langfun.core import modalities as lf_modalities from langfun.core.llms import google_genai -import pyglove as pg - - -example_image = ( - b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' - b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' - b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' - b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' - b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' - b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' - b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' - b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' -) - - -def mock_get_model(model_name, *args, **kwargs): - del args, kwargs - if 'gemini' in model_name: - method = 'generateContent' - elif 'chat' in model_name: - method = 'generateMessage' - else: - method = 'generateText' - return pg.Dict(supported_generation_methods=[method]) - - -def mock_generate_text(*, model, prompt, **kwargs): - return pg.Dict( - candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')] - ) - - -def mock_chat(*, model, messages, **kwargs): - return pg.Dict( - candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')] - ) - - -def mock_generate_content(content, generation_config, **kwargs): - del kwargs - c = generation_config - return genai.types.GenerateContentResponse( - done=True, - iterator=None, - chunks=[], - result=pg.Dict( - prompt_feedback=pg.Dict(block_reason=None), - candidates=[ - pg.Dict( - content=pg.Dict( - parts=[ - pg.Dict( - text=( - f'This is a response to {content[0]} with ' - f'n={c.candidate_count}, ' - f'temperature={c.temperature}, ' - f'top_p={c.top_p}, ' - f'top_k={c.top_k}, ' - f'max_tokens={c.max_output_tokens}, ' - f'stop={c.stop_sequences}.' - ) - ) - ] - ), - ), - ], - ), - ) class GenAITest(unittest.TestCase): - """Tests for Google GenAI model.""" - - def test_content_from_message_text_only(self): - text = 'This is a beautiful day' - model = google_genai.GeminiPro() - chunks = model._content_from_message(lf.UserMessage(text)) - self.assertEqual(chunks, [text]) - - def test_content_from_message_mm(self): - message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', - image=lf_modalities.Image.from_bytes(example_image), - ) + """Tests for GenAI model.""" - # Non-multimodal model. - with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): - google_genai.GeminiPro()._content_from_message(message) - - model = google_genai.GeminiProVision() - chunks = model._content_from_message(message) - self.maxDiff = None - self.assertEqual( - chunks, - [ - 'This is an', - genai.types.BlobDict(mime_type='image/png', data=example_image), - ', what is it?', - ], - ) - - def test_response_to_result_text_only(self): - response = genai.types.GenerateContentResponse( - done=True, - iterator=None, - chunks=[], - result=pg.Dict( - prompt_feedback=pg.Dict(block_reason=None), - candidates=[ - pg.Dict( - content=pg.Dict( - parts=[pg.Dict(text='This is response 1.')] - ), - ), - pg.Dict( - content=pg.Dict(parts=[pg.Dict(text='This is response 2.')]) - ), - ], - ), - ) - model = google_genai.GeminiProVision() - result = model._response_to_result(response) - self.assertEqual( - result, - lf.LMSamplingResult([ - lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0), - lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0), - ]), - ) - - def test_model_hub(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - - model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro') - self.assertIsNotNone(model) - self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model) - - genai.get_model = orig_get_model - - def test_api_key_check(self): + def test_basics(self): with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'): - _ = google_genai.GeminiPro()._api_initialized + _ = google_genai.GeminiPro1_5().api_endpoint + + self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint) - self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized) os.environ['GOOGLE_API_KEY'] = 'abc' - self.assertTrue(google_genai.GeminiPro()._api_initialized) + lm = google_genai.GeminiPro1_5() + self.assertIsNotNone(lm.api_endpoint) + self.assertTrue(lm.model_id.startswith('GenAI(')) del os.environ['GOOGLE_API_KEY'] - def test_call(self): - with mock.patch( - 'google.generativeai.GenerativeModel.generate_content', - ) as mock_generate: - orig_get_model = genai.get_model - genai.get_model = mock_get_model - mock_generate.side_effect = mock_generate_content - - lm = google_genai.GeminiPro(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text, - ( - 'This is a response to hello with n=1, temperature=2.0, ' - 'top_p=None, top_k=20, max_tokens=1024, stop=None.' - ), - ) - genai.get_model = orig_get_model - - def test_call_with_legacy_completion_model(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - orig_generate_text = getattr(genai, 'generate_text', None) - if orig_generate_text is not None: - genai.generate_text = mock_generate_text - - lm = google_genai.Palm2(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20).text, - ( - "hello to models/text-bison-001 with {'temperature': 2.0, " - "'top_k': 20, 'top_p': None, 'candidate_count': 1, " - "'max_output_tokens': None, 'stop_sequences': None}" - ), - ) - genai.generate_text = orig_generate_text - genai.get_model = orig_get_model - - def test_call_with_legacy_chat_model(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - orig_chat = getattr(genai, 'chat', None) - if orig_chat is not None: - genai.chat = mock_chat - - lm = google_genai.Palm2_IT(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20).text, - ( - "hello to models/chat-bison-001 with {'temperature': 2.0, " - "'top_k': 20, 'top_p': None, 'candidate_count': 1}" - ), - ) - genai.chat = orig_chat - genai.get_model = orig_get_model - if __name__ == '__main__': unittest.main() diff --git a/langfun/core/llms/vertexai.py b/langfun/core/llms/vertexai.py index 24aac4cc..fd4c698d 100644 --- a/langfun/core/llms/vertexai.py +++ b/langfun/core/llms/vertexai.py @@ -13,14 +13,12 @@ # limitations under the License. """Vertex AI generative models.""" -import base64 import functools import os from typing import Annotated, Any import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import rest +from langfun.core.llms import gemini import pyglove as pg try: @@ -38,114 +36,11 @@ Credentials = Any -# https://cloud.google.com/vertex-ai/generative-ai/pricing -# describes that the average number of characters per token is about 4. -AVGERAGE_CHARS_PER_TOKEN = 4 - - -# Price in US dollars, -# from https://cloud.google.com/vertex-ai/generative-ai/pricing -# as of 2024-10-10. -SUPPORTED_MODELS_AND_SETTINGS = { - 'gemini-1.5-pro-001': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-pro-002': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash-002': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-flash-001': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-pro': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-pro-preview-0514': pg.Dict( - rpm=50, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-pro-preview-0409': pg.Dict( - rpm=50, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash-preview-0514': pg.Dict( - rpm=200, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.0-pro': pg.Dict( - rpm=300, - cost_per_1k_input_chars=0.000125, - cost_per_1k_output_chars=0.000375, - ), - 'gemini-1.0-pro-vision': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.000125, - cost_per_1k_output_chars=0.000375, - ), - # TODO(sharatsharat): Update costs when published - 'gemini-exp-1206': pg.Dict( - rpm=20, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(sharatsharat): Update costs when published - 'gemini-2.0-flash-exp': pg.Dict( - rpm=10, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(yifenglu): Update costs when published - 'gemini-2.0-flash-thinking-exp-1219': pg.Dict( - rpm=10, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(chengrun): Set a more appropriate rpm for endpoint. - 'vertexai-endpoint': pg.Dict( - rpm=20, - cost_per_1k_input_chars=0.0000125, - cost_per_1k_output_chars=0.0000375, - ), -} - - @lf.use_init_args(['model']) @pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) -class VertexAI(rest.REST): +class VertexAI(gemini.Gemini): """Language model served on VertexAI with REST API.""" - model: pg.typing.Annotated[ - pg.typing.Enum( - pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys()) - ), - ( - 'Vertex AI model name with REST API support. See ' - 'https://cloud.google.com/vertex-ai/generative-ai/docs/' - 'model-reference/inference#supported-models' - ' for details.' - ), - ] - project: Annotated[ str | None, ( @@ -170,11 +65,6 @@ class VertexAI(rest.REST): ), ] = None - supported_modalities: Annotated[ - list[str], - 'A list of MIME types for supported modalities' - ] = [] - def _on_bound(self): super()._on_bound() if google_auth is None: @@ -209,31 +99,9 @@ def _initialize(self): self._credentials = credentials @property - def max_concurrency(self) -> int: - """Returns the maximum number of concurrent requests.""" - return self.rate_to_max_concurrency( - requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm, - tokens_per_min=0, - ) - - def estimate_cost( - self, - num_input_tokens: int, - num_output_tokens: int - ) -> float | None: - """Estimate the cost based on usage.""" - cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_input_chars', None - ) - cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_output_chars', None - ) - if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None: - return None - return ( - cost_per_1k_input_chars * num_input_tokens - + cost_per_1k_output_chars * num_output_tokens - ) * AVGERAGE_CHARS_PER_TOKEN / 1000 + def model_id(self) -> str: + """Returns a string to identify the model.""" + return f'VertexAI({self.model})' @functools.cached_property def _session(self): @@ -244,12 +112,6 @@ def _session(self): s.headers.update(self.headers or {}) return s - @property - def headers(self): - return { - 'Content-Type': 'application/json; charset=utf-8', - } - @property def api_endpoint(self) -> str: return ( @@ -258,263 +120,69 @@ def api_endpoint(self) -> str: f'models/{self.model}:generateContent' ) - def request( - self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - request = dict( - generationConfig=self._generation_config(prompt, sampling_options) - ) - request['contents'] = [self._content_from_message(prompt)] - return request - - def _generation_config( - self, prompt: lf.Message, options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns a dict as generation config for prompt and LMSamplingOptions.""" - config = dict( - temperature=options.temperature, - maxOutputTokens=options.max_tokens, - candidateCount=options.n, - topK=options.top_k, - topP=options.top_p, - stopSequences=options.stop, - seed=options.random_seed, - responseLogprobs=options.logprobs, - logprobs=options.top_logprobs, - ) - if json_schema := prompt.metadata.get('json_schema'): - if not isinstance(json_schema, dict): - raise ValueError( - f'`json_schema` must be a dict, got {json_schema!r}.' - ) - json_schema = pg.to_json(json_schema) - config['responseSchema'] = json_schema - config['responseMimeType'] = 'application/json' - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(json_schema, json_indent=2) - ) - return config - - def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]: - """Gets generation content from langfun message.""" - parts = [] - for lf_chunk in prompt.chunk(): - if isinstance(lf_chunk, str): - parts.append({'text': lf_chunk}) - elif isinstance(lf_chunk, lf_modalities.Mime): - try: - modalities = lf_chunk.make_compatible( - self.supported_modalities + ['text/plain'] - ) - if isinstance(modalities, lf_modalities.Mime): - modalities = [modalities] - for modality in modalities: - if modality.is_text: - parts.append({'text': modality.to_text()}) - else: - parts.append({ - 'inlineData': { - 'data': base64.b64encode(modality.to_bytes()).decode(), - 'mimeType': modality.mime_type, - } - }) - except lf.ModalityError as e: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e - else: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') - return dict(role='user', parts=parts) - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - messages = [ - self._message_from_content_parts(candidate['content']['parts']) - for candidate in json['candidates'] - ] - usage = json['usageMetadata'] - input_tokens = usage['promptTokenCount'] - output_tokens = usage['candidatesTokenCount'] - return lf.LMSamplingResult( - [lf.LMSample(message) for message in messages], - usage=lf.LMSamplingUsage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - estimated_cost=self.estimate_cost( - num_input_tokens=input_tokens, - num_output_tokens=output_tokens, - ), - ), - ) +class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024.""" - def _message_from_content_parts( - self, parts: list[dict[str, Any]] - ) -> lf.Message: - """Converts Vertex AI's content parts protocol to message.""" - chunks = [] - for part in parts: - if text_part := part.get('text'): - chunks.append(text_part) - else: - raise ValueError(f'Unsupported part: {part}') - return lf.AIMessage.from_chunks(chunks) - - -IMAGE_TYPES = [ - 'image/png', - 'image/jpeg', - 'image/webp', - 'image/heic', - 'image/heif', -] - -AUDIO_TYPES = [ - 'audio/aac', - 'audio/flac', - 'audio/mp3', - 'audio/m4a', - 'audio/mpeg', - 'audio/mpga', - 'audio/mp4', - 'audio/opus', - 'audio/pcm', - 'audio/wav', - 'audio/webm', -] - -VIDEO_TYPES = [ - 'video/mov', - 'video/mpeg', - 'video/mpegps', - 'video/mpg', - 'video/mp4', - 'video/webm', - 'video/wmv', - 'video/x-flv', - 'video/3gpp', - 'video/quicktime', -] - -DOCUMENT_TYPES = [ - 'application/pdf', - 'text/plain', - 'text/csv', - 'text/html', - 'text/xml', - 'text/x-script.python', - 'application/json', -] - - -class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 2.0 model.""" - - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES - ) - - -class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name + api_version = 'v1alpha' + model = 'gemini-2.0-flash-thinking-exp-1219' + + +class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 2.0 Flash model.""" model = 'gemini-2.0-flash-exp' -class VertexAIGeminiFlash2_0ThinkingExp(VertexAIGemini2_0): # pylint: disable=invalid-name - """Vertex AI Gemini 2.0 Flash model.""" +class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Experimental model launched on 12/06/2024.""" - model = 'gemini-2.0-flash-thinking-exp-1219' + model = 'gemini-exp-1206' -class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 model.""" +class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Experimental model launched on 11/14/2024.""" - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES - ) + model = 'gemini-exp-1114' -class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" - model = 'gemini-1.5-pro' + model = 'gemini-1.5-pro-latest' -class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-002' -class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-001' -class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Pro preview model.""" - - model = 'gemini-1.5-pro-preview-0514' - - -class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Pro preview model.""" - - model = 'gemini-1.5-pro-preview-0409' - - -class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash' -class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash-002' -class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash-001' -class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Flash preview model.""" - - model = 'gemini-1.5-flash-preview-0514' - - class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.0 Pro model.""" model = 'gemini-1.0-pro' - - -class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 1.0 Pro Vision model.""" - - model = 'gemini-1.0-pro-vision' - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - IMAGE_TYPES + VIDEO_TYPES - ) - - -class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name - """Vertex AI Endpoint model.""" - - model = 'vertexai-endpoint' - - endpoint: Annotated[str, 'Vertex AI Endpoint ID.'] - - @property - def api_endpoint(self) -> str: - return ( - f'https://{self.location}-aiplatform.googleapis.com/v1/projects/' - f'{self.project}/locations/{self.location}/' - f'endpoints/{self.endpoint}:generateContent' - ) diff --git a/langfun/core/llms/vertexai_test.py b/langfun/core/llms/vertexai_test.py index 3e9d1cac..3e7989ee 100644 --- a/langfun/core/llms/vertexai_test.py +++ b/langfun/core/llms/vertexai_test.py @@ -11,105 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Gemini models.""" +"""Tests for VertexAI models.""" -import base64 import os -from typing import Any import unittest from unittest import mock -import langfun.core as lf -from langfun.core import modalities as lf_modalities from langfun.core.llms import vertexai -import pyglove as pg -import requests - - -example_image = ( - b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' - b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' - b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' - b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' - b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' - b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' - b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' - b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' -) - - -def mock_requests_post(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - c = pg.Dict(json['generationConfig']) - content = json['contents'][0]['parts'][0]['text'] - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str({ - 'candidates': [ - { - 'content': { - 'role': 'model', - 'parts': [ - { - 'text': ( - f'This is a response to {content} with ' - f'temperature={c.temperature}, ' - f'top_p={c.topP}, ' - f'top_k={c.topK}, ' - f'max_tokens={c.maxOutputTokens}, ' - f'stop={"".join(c.stopSequences)}.' - ) - }, - ], - }, - }, - ], - 'usageMetadata': { - 'promptTokenCount': 3, - 'candidatesTokenCount': 4, - } - }).encode() - return response class VertexAITest(unittest.TestCase): """Tests for Vertex model with REST API.""" - def test_content_from_message_text_only(self): - text = 'This is a beautiful day' - model = vertexai.VertexAIGeminiPro1_5_002() - chunks = model._content_from_message(lf.UserMessage(text)) - self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]}) - - def test_content_from_message_mm(self): - image = lf_modalities.Image.from_bytes(example_image) - message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', image=image - ) - - # Non-multimodal model. - with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): - vertexai.VertexAIGeminiPro1()._content_from_message(message) - - model = vertexai.VertexAIGeminiPro1Vision() - content = model._content_from_message(message) - self.assertEqual( - content, - { - 'role': 'user', - 'parts': [ - {'text': 'This is an'}, - { - 'inlineData': { - 'data': base64.b64encode(example_image).decode(), - 'mimeType': 'image/png', - } - }, - {'text': ', what is it?'}, - ], - }, - ) - @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) def test_project_and_location_check(self): with self.assertRaisesRegex(ValueError, 'Please specify `project`'): @@ -126,87 +39,14 @@ def test_project_and_location_check(self): os.environ['VERTEXAI_PROJECT'] = 'abc' os.environ['VERTEXAI_LOCATION'] = 'us-central1' - self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized) + model = vertexai.VertexAIGeminiPro1() + self.assertTrue(model.model_id.startswith('VertexAI(')) + self.assertIsNotNone(model.api_endpoint) + self.assertTrue(model._api_initialized) + self.assertIsNotNone(model._session) del os.environ['VERTEXAI_PROJECT'] del os.environ['VERTEXAI_LOCATION'] - def test_generation_config(self): - model = vertexai.VertexAIGeminiPro1() - json_schema = { - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - }, - 'required': ['name'], - 'title': 'Person', - } - actual = model._generation_config( - lf.UserMessage('hi', json_schema=json_schema), - lf.LMSamplingOptions( - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop=['\n'], - ), - ) - self.assertEqual( - actual, - dict( - candidateCount=1, - temperature=2.0, - topP=1.0, - topK=20, - maxOutputTokens=1024, - stopSequences=['\n'], - responseLogprobs=False, - logprobs=None, - seed=None, - responseMimeType='application/json', - responseSchema={ - 'type': 'object', - 'properties': { - 'name': {'type': 'string'} - }, - 'required': ['name'], - 'title': 'Person', - } - ), - ) - with self.assertRaisesRegex( - ValueError, '`json_schema` must be a dict, got' - ): - model._generation_config( - lf.UserMessage('hi', json_schema='not a dict'), - lf.LMSamplingOptions(), - ) - - @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) - def test_call_model(self): - with mock.patch('requests.Session.post') as mock_generate: - mock_generate.side_effect = mock_requests_post - - lm = vertexai.VertexAIGeminiPro1_5_002( - project='abc', location='us-central1' - ) - r = lm( - 'hello', - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop='\n', - ) - self.assertEqual( - r.text, - ( - 'This is a response to hello with temperature=2.0, ' - 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.' - ), - ) - self.assertEqual(r.metadata.usage.prompt_tokens, 3) - self.assertEqual(r.metadata.usage.completion_tokens, 4) - if __name__ == '__main__': unittest.main() diff --git a/requirements.txt b/requirements.txt index 440c8c51..babae31e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,8 @@ requests>=2.31.0 termcolor==1.1.0 tqdm>=4.64.1 -# extras:llm-google-genai -google-generativeai>=0.3.2 +# extras:vertexai +google-auth>=2.16.0 # extras:mime-auto python-magic>=0.4.27