diff --git a/README.md b/README.md index aaa04ea..553dce3 100644 --- a/README.md +++ b/README.md @@ -138,9 +138,7 @@ If you want to customize your installation, you can select specific features usi | Tag | Description | | ------------------- | ---------------------------------------- | | all | All Langfun features. | -| llm | All supported LLMs. | -| llm-google | All supported Google-powered LLMs. | -| llm-google-genai | LLMs powered by Google Generative AI API | +| vertexai | VertexAI access. | | mime | All MIME supports. | | mime-auto | Automatic MIME type detection. | | mime-docx | DocX format support. | @@ -149,9 +147,9 @@ If you want to customize your installation, you can select specific features usi | ui | UI enhancements | -For example, to install a nightly build that includes Google-powered LLMs, full modality support, and UI enhancements, use: +For example, to install a nightly build that includes VertexAI access, full modality support, and UI enhancements, use: ``` -pip install langfun[llm-google,mime,ui] --pre +pip install langfun[vertexai,mime,ui] --pre ``` *Disclaimer: this is not an officially supported Google product.* diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py index 5e39028..84923dd 100644 --- a/langfun/core/llms/__init__.py +++ b/langfun/core/llms/__init__.py @@ -32,16 +32,30 @@ # Gemini models. from langfun.core.llms.google_genai import GenAI -from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp +from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219 from langfun.core.llms.google_genai import GeminiFlash2_0Exp -from langfun.core.llms.google_genai import GeminiExp_20241114 from langfun.core.llms.google_genai import GeminiExp_20241206 -from langfun.core.llms.google_genai import GeminiFlash1_5 +from langfun.core.llms.google_genai import GeminiExp_20241114 from langfun.core.llms.google_genai import GeminiPro1_5 -from langfun.core.llms.google_genai import GeminiPro -from langfun.core.llms.google_genai import GeminiProVision -from langfun.core.llms.google_genai import Palm2 -from langfun.core.llms.google_genai import Palm2_IT +from langfun.core.llms.google_genai import GeminiPro1_5_002 +from langfun.core.llms.google_genai import GeminiPro1_5_001 +from langfun.core.llms.google_genai import GeminiFlash1_5 +from langfun.core.llms.google_genai import GeminiFlash1_5_002 +from langfun.core.llms.google_genai import GeminiFlash1_5_001 +from langfun.core.llms.google_genai import GeminiPro1 + +from langfun.core.llms.vertexai import VertexAI +from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219 +from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp +from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206 +from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002 +from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002 +from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001 +from langfun.core.llms.vertexai import VertexAIGeminiPro1 # OpenAI models. 
from langfun.core.llms.openai import OpenAI @@ -124,25 +138,6 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo -from langfun.core.llms.vertexai import VertexAI -from langfun.core.llms.vertexai import VertexAIGemini2_0 -from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp -from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp -from langfun.core.llms.vertexai import VertexAIGemini1_5 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514 -from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002 -from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514 -from langfun.core.llms.vertexai import VertexAIGeminiPro1 -from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision -from langfun.core.llms.vertexai import VertexAIEndpoint - - # LLaMA C++ models. from langfun.core.llms.llama_cpp import LlamaCppRemote diff --git a/langfun/core/llms/gemini.py b/langfun/core/llms/gemini.py new file mode 100644 index 0000000..f5bbbda --- /dev/null +++ b/langfun/core/llms/gemini.py @@ -0,0 +1,500 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemini REST API (Shared by Google GenAI and Vertex AI).""" + +import base64 +from typing import Any + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import rest +import pyglove as pg + +# Supported modalities. + +IMAGE_TYPES = [ + 'image/png', + 'image/jpeg', + 'image/webp', + 'image/heic', + 'image/heif', +] + +AUDIO_TYPES = [ + 'audio/aac', + 'audio/flac', + 'audio/mp3', + 'audio/m4a', + 'audio/mpeg', + 'audio/mpga', + 'audio/mp4', + 'audio/opus', + 'audio/pcm', + 'audio/wav', + 'audio/webm', +] + +VIDEO_TYPES = [ + 'video/mov', + 'video/mpeg', + 'video/mpegps', + 'video/mpg', + 'video/mp4', + 'video/webm', + 'video/wmv', + 'video/x-flv', + 'video/3gpp', + 'video/quicktime', +] + +DOCUMENT_TYPES = [ + 'application/pdf', + 'text/plain', + 'text/csv', + 'text/html', + 'text/xml', + 'text/x-script.python', + 'application/json', +] + +TEXT_ONLY = [] + +ALL_MODALITIES = ( + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES +) + +SUPPORTED_MODELS_AND_SETTINGS = { + # For automatic rate control and cost estimation, we explicitly register + # supported models here. This may be inconvenient for new models, but it + # helps us to keep track of the models and their pricing.
+ # Models and RPM are from + # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB + # Pricing in US dollars, from https://ai.google.dev/pricing + # as of 2025-01-03. + # NOTE: Please update google_genai.py, vertexai.py, __init__.py when + # adding new models. + # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!! + 'gemini-2.0-flash-thinking-exp-1219': pg.Dict( + latest_update='2024-12-19', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-2.0-flash-exp': pg.Dict( + latest_update='2024-12-11', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-exp-1206': pg.Dict( + latest_update='2024-12-06', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'learnlm-1.5-pro-experimental': pg.Dict( + latest_update='2024-11-19', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-exp-1114': pg.Dict( + latest_update='2024-11-14', + experimental=True, + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=10, + tpm_free=4_000_000, + rpm_paid=0, + tpm_paid=0, + cost_per_1m_input_tokens_up_to_128k=0, + cost_per_1m_output_tokens_up_to_128k=0, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0, + cost_per_1m_output_tokens_longer_than_128k=0, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), + 'gemini-1.5-flash-latest': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + 
cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-001': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-002': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=2000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.075, + cost_per_1m_output_tokens_up_to_128k=0.3, + cost_per_1m_cached_tokens_up_to_128k=0.01875, + cost_per_1m_input_tokens_longer_than_128k=0.15, + cost_per_1m_output_tokens_longer_than_128k=0.6, + cost_per_1m_cached_tokens_longer_than_128k=0.0375, + ), + 'gemini-1.5-flash-8b': pg.Dict( + latest_update='2024-10-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=4000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.0375, + cost_per_1m_output_tokens_up_to_128k=0.15, + cost_per_1m_cached_tokens_up_to_128k=0.01, + cost_per_1m_input_tokens_longer_than_128k=0.075, + cost_per_1m_output_tokens_longer_than_128k=0.3, + cost_per_1m_cached_tokens_longer_than_128k=0.02, + ), + 'gemini-1.5-flash-8b-001': pg.Dict( + latest_update='2024-10-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=15, + tpm_free=1_000_000, + rpm_paid=4000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=0.0375, + cost_per_1m_output_tokens_up_to_128k=0.15, + cost_per_1m_cached_tokens_up_to_128k=0.01, + cost_per_1m_input_tokens_longer_than_128k=0.075, + cost_per_1m_output_tokens_longer_than_128k=0.3, + cost_per_1m_cached_tokens_longer_than_128k=0.02, + ), + 'gemini-1.5-pro-latest': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro-001': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + 
cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.5-pro-002': pg.Dict( + latest_update='2024-09-30', + in_service=True, + supported_modalities=ALL_MODALITIES, + rpm_free=2, + tpm_free=32_000, + rpm_paid=1000, + tpm_paid=4_000_000, + cost_per_1m_input_tokens_up_to_128k=1.25, + cost_per_1m_output_tokens_up_to_128k=5.00, + cost_per_1m_cached_tokens_up_to_128k=0.3125, + cost_per_1m_input_tokens_longer_than_128k=2.5, + cost_per_1m_output_tokens_longer_than_128k=10.00, + cost_per_1m_cached_tokens_longer_than_128k=0.625, + ), + 'gemini-1.0-pro': pg.Dict( + in_service=False, + supported_modalities=TEXT_ONLY, + rpm_free=15, + tpm_free=32_000, + rpm_paid=360, + tpm_paid=120_000, + cost_per_1m_input_tokens_up_to_128k=0.5, + cost_per_1m_output_tokens_up_to_128k=1.5, + cost_per_1m_cached_tokens_up_to_128k=0, + cost_per_1m_input_tokens_longer_than_128k=0.5, + cost_per_1m_output_tokens_longer_than_128k=1.5, + cost_per_1m_cached_tokens_longer_than_128k=0, + ), +} + + +@pg.use_init_args(['model']) +class Gemini(rest.REST): + """Gemini models served via REST API, shared by Google GenAI and Vertex AI.""" + + model: pg.typing.Annotated[ + pg.typing.Enum( + pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys()) + ), + 'The name of the model to use.', + ] + + @property + def supported_modalities(self) -> list[str]: + """Returns the list of supported modalities.""" + return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities + + @property + def max_concurrency(self) -> int: + """Returns the maximum number of concurrent requests.""" + return self.rate_to_max_concurrency( + requests_per_min=max( + SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free, + SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid + ), + tokens_per_min=max( + SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free, + SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid, + ), + ) + + def estimate_cost( + self, + num_input_tokens: int, + num_output_tokens: int + ) -> float | None: + """Estimate the cost based on usage.""" + entry = SUPPORTED_MODELS_AND_SETTINGS[self.model] + if num_input_tokens < 128_000: + cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k + cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k + else: + cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k + cost_per_1m_output_tokens = ( + entry.cost_per_1m_output_tokens_longer_than_128k + ) + return ( + cost_per_1m_input_tokens * num_input_tokens + + cost_per_1m_output_tokens * num_output_tokens + ) / 1_000_000 + + @property + def model_id(self) -> str: + """Returns a string to identify the model.""" + return self.model + + @classmethod + def dir(cls): + return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service] + + @property + def headers(self): + return { + 'Content-Type': 'application/json; charset=utf-8', + } + + def request( + self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions + ) -> dict[str, Any]: + request = dict( + generationConfig=self._generation_config(prompt, sampling_options) + ) + request['contents'] = [self._content_from_message(prompt)] + return request + + def _generation_config( + self, prompt: lf.Message, options: lf.LMSamplingOptions + ) -> dict[str, Any]: + """Returns a dict as generation config for prompt and LMSamplingOptions.""" + config = dict( + temperature=options.temperature, + maxOutputTokens=options.max_tokens, +
candidateCount=options.n, + topK=options.top_k, + topP=options.top_p, + stopSequences=options.stop, + seed=options.random_seed, + responseLogprobs=options.logprobs, + logprobs=options.top_logprobs, + ) + + if json_schema := prompt.metadata.get('json_schema'): + if not isinstance(json_schema, dict): + raise ValueError( + f'`json_schema` must be a dict, got {json_schema!r}.' + ) + json_schema = pg.to_json(json_schema) + config['responseSchema'] = json_schema + config['responseMimeType'] = 'application/json' + prompt.metadata.formatted_text = ( + prompt.text + + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' + + pg.to_json_str(json_schema, json_indent=2) + ) + return config + + def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]: + """Gets generation content from langfun message.""" + parts = [] + for lf_chunk in prompt.chunk(): + if isinstance(lf_chunk, str): + parts.append({'text': lf_chunk}) + elif isinstance(lf_chunk, lf_modalities.Mime): + try: + modalities = lf_chunk.make_compatible( + self.supported_modalities + ['text/plain'] + ) + if isinstance(modalities, lf_modalities.Mime): + modalities = [modalities] + for modality in modalities: + if modality.is_text: + parts.append({'text': modality.to_text()}) + else: + parts.append({ + 'inlineData': { + 'data': base64.b64encode(modality.to_bytes()).decode(), + 'mimeType': modality.mime_type, + } + }) + except lf.ModalityError as e: + raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e + else: + raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') + return dict(role='user', parts=parts) + + def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: + messages = [ + self._message_from_content_parts(candidate['content']['parts']) + for candidate in json['candidates'] + ] + usage = json['usageMetadata'] + input_tokens = usage['promptTokenCount'] + output_tokens = usage['candidatesTokenCount'] + return lf.LMSamplingResult( + [lf.LMSample(message) for message in messages], + usage=lf.LMSamplingUsage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + estimated_cost=self.estimate_cost( + num_input_tokens=input_tokens, + num_output_tokens=output_tokens, + ), + ), + ) + + def _message_from_content_parts( + self, parts: list[dict[str, Any]] + ) -> lf.Message: + """Converts Gemini content parts to a langfun message.""" + chunks = [] + for part in parts: + if text_part := part.get('text'): + chunks.append(text_part) + else: + raise ValueError(f'Unsupported part: {part}') + return lf.AIMessage.from_chunks(chunks) diff --git a/langfun/core/llms/gemini_test.py b/langfun/core/llms/gemini_test.py new file mode 100644 index 0000000..4f82696 --- /dev/null +++ b/langfun/core/llms/gemini_test.py @@ -0,0 +1,190 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+"""Tests for Gemini API.""" + +import base64 +from typing import Any +import unittest +from unittest import mock + +import langfun.core as lf +from langfun.core import modalities as lf_modalities +from langfun.core.llms import gemini +import pyglove as pg +import requests + + +example_image = ( + b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' + b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' + b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' + b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' + b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' + b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' + b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' + b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' +) + + +def mock_requests_post(url: str, json: dict[str, Any], **kwargs): + del url, kwargs + c = pg.Dict(json['generationConfig']) + content = json['contents'][0]['parts'][0]['text'] + response = requests.Response() + response.status_code = 200 + response._content = pg.to_json_str({ + 'candidates': [ + { + 'content': { + 'role': 'model', + 'parts': [ + { + 'text': ( + f'This is a response to {content} with ' + f'temperature={c.temperature}, ' + f'top_p={c.topP}, ' + f'top_k={c.topK}, ' + f'max_tokens={c.maxOutputTokens}, ' + f'stop={"".join(c.stopSequences)}.' + ) + }, + ], + }, + }, + ], + 'usageMetadata': { + 'promptTokenCount': 3, + 'candidatesTokenCount': 4, + } + }).encode() + return response + + +class GeminiTest(unittest.TestCase): + """Tests for Vertex model with REST API.""" + + def test_content_from_message_text_only(self): + text = 'This is a beautiful day' + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + chunks = model._content_from_message(lf.UserMessage(text)) + self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]}) + + def test_content_from_message_mm(self): + image = lf_modalities.Image.from_bytes(example_image) + message = lf.UserMessage( + 'This is an <<[[image]]>>, what is it?', image=image + ) + + # Non-multimodal model. 
+ with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): + gemini.Gemini( + 'gemini-1.0-pro', api_endpoint='' + )._content_from_message(message) + + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + content = model._content_from_message(message) + self.assertEqual( + content, + { + 'role': 'user', + 'parts': [ + {'text': 'This is an'}, + { + 'inlineData': { + 'data': base64.b64encode(example_image).decode(), + 'mimeType': 'image/png', + } + }, + {'text': ', what is it?'}, + ], + }, + ) + + def test_generation_config(self): + model = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + json_schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + }, + 'required': ['name'], + 'title': 'Person', + } + actual = model._generation_config( + lf.UserMessage('hi', json_schema=json_schema), + lf.LMSamplingOptions( + temperature=2.0, + top_p=1.0, + top_k=20, + max_tokens=1024, + stop=['\n'], + ), + ) + self.assertEqual( + actual, + dict( + candidateCount=1, + temperature=2.0, + topP=1.0, + topK=20, + maxOutputTokens=1024, + stopSequences=['\n'], + responseLogprobs=False, + logprobs=None, + seed=None, + responseMimeType='application/json', + responseSchema={ + 'type': 'object', + 'properties': { + 'name': {'type': 'string'} + }, + 'required': ['name'], + 'title': 'Person', + } + ), + ) + with self.assertRaisesRegex( + ValueError, '`json_schema` must be a dict, got' + ): + model._generation_config( + lf.UserMessage('hi', json_schema='not a dict'), + lf.LMSamplingOptions(), + ) + + def test_call_model(self): + with mock.patch('requests.Session.post') as mock_generate: + mock_generate.side_effect = mock_requests_post + + lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='') + r = lm( + 'hello', + temperature=2.0, + top_p=1.0, + top_k=20, + max_tokens=1024, + stop='\n', + ) + self.assertEqual( + r.text, + ( + 'This is a response to hello with temperature=2.0, ' + 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.' + ), + ) + self.assertEqual(r.metadata.usage.prompt_tokens, 3) + self.assertEqual(r.metadata.usage.completion_tokens, 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/langfun/core/llms/google_genai.py b/langfun/core/llms/google_genai.py index 32af0ee..4b45487 100644 --- a/langfun/core/llms/google_genai.py +++ b/langfun/core/llms/google_genai.py @@ -1,4 +1,4 @@ -# Copyright 2024 The Langfun Authors +# Copyright 2025 The Langfun Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,57 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Gemini models exposed through Google Generative AI APIs.""" +"""Language models from Google GenAI.""" -import abc -import functools import os -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import vertexai +from langfun.core.llms import gemini import pyglove as pg -try: - import google.generativeai as genai # pylint: disable=g-import-not-at-top - BlobDict = genai.types.BlobDict - GenerativeModel = genai.GenerativeModel - Completion = getattr(genai.types, 'Completion', Any) - ChatResponse = getattr(genai.types, 'ChatResponse', Any) - GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any) - GenerationConfig = genai.GenerationConfig -except ImportError: - genai = None - BlobDict = Any - GenerativeModel = Any - Completion = Any - ChatResponse = Any - GenerationConfig = Any - GenerateContentResponse = Any - - @lf.use_init_args(['model']) -class GenAI(lf.LanguageModel): +@pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) +class GenAI(gemini.Gemini): """Language models provided by Google GenAI.""" - model: Annotated[ - Literal[ - 'gemini-2.0-flash-thinking-exp-1219', - 'gemini-2.0-flash-exp', - 'gemini-exp-1206', - 'gemini-exp-1114', - 'gemini-1.5-pro-latest', - 'gemini-1.5-flash-latest', - 'gemini-pro', - 'gemini-pro-vision', - 'text-bison-001', - 'chat-bison-001', - ], - 'Model name.', - ] - api_key: Annotated[ str | None, ( @@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel): ), ] = None - supported_modalities: Annotated[ - list[str], - 'A list of MIME types for supported modalities' - ] = [] + api_version: Annotated[ + Literal['v1beta', 'v1alpha'], + 'The API version to use.' + ] = 'v1beta' - # Set the default max concurrency to 8 workers. - max_concurrency = 8 - - def _on_bound(self): - super()._on_bound() - if genai is None: - raise RuntimeError( - 'Please install "langfun[llm-google-genai]" to use ' - 'Google Generative AI models.' - ) - self.__dict__.pop('_api_initialized', None) + @property + def model_id(self) -> str: + """Returns a string to identify the model.""" + return f'GenAI({self.model})' - @functools.cached_property - def _api_initialized(self): - assert genai is not None + @property + def api_endpoint(self) -> str: api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None) if not api_key: raise ValueError( @@ -100,306 +56,75 @@ def _api_initialized(self): 'https://cloud.google.com/api-keys/docs/create-manage-api-keys ' 'for more details.' 
) - genai.configure(api_key=api_key) - return True - - @classmethod - def dir(cls) -> list[str]: - """Lists generative models.""" - assert genai is not None - return [ - m.name.lstrip('models/') - for m in genai.list_models() - if ( - 'generateContent' in m.supported_generation_methods - or 'generateText' in m.supported_generation_methods - or 'generateMessage' in m.supported_generation_methods - ) - ] - - @property - def model_id(self) -> str: - """Returns a string to identify the model.""" - return self.model - - @property - def resource_id(self) -> str: - """Returns a string to identify the resource for rate control.""" - return self.model_id - - def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]: - """Creates generation config from langfun sampling options.""" - return GenerationConfig( - candidate_count=options.n, - temperature=options.temperature, - top_p=options.top_p, - top_k=options.top_k, - max_output_tokens=options.max_tokens, - stop_sequences=options.stop, + return ( + f'https://generativelanguage.googleapis.com/{self.api_version}' + f'/models/{self.model}:generateContent?' + f'key={api_key}' ) - def _content_from_message( - self, prompt: lf.Message - ) -> list[str | BlobDict]: - """Gets Evergreen formatted content from langfun message.""" - formatted = lf.UserMessage(prompt.text) - formatted.source = prompt - - chunks = [] - for lf_chunk in formatted.chunk(): - if isinstance(lf_chunk, str): - chunks.append(lf_chunk) - elif isinstance(lf_chunk, lf_modalities.Mime): - try: - modalities = lf_chunk.make_compatible( - self.supported_modalities + ['text/plain'] - ) - if isinstance(modalities, lf_modalities.Mime): - modalities = [modalities] - for modality in modalities: - if modality.is_text: - chunk = modality.to_text() - else: - chunk = BlobDict( - data=modality.to_bytes(), - mime_type=modality.mime_type - ) - chunks.append(chunk) - except lf.ModalityError as e: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e - else: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') - return chunks - - def _response_to_result( - self, response: GenerateContentResponse | pg.Dict - ) -> lf.LMSamplingResult: - """Parses generative response into message.""" - samples = [] - for candidate in response.candidates: - chunks = [] - for part in candidate.content.parts: - # TODO(daiyip): support multi-modal parts when they are available via - # Gemini API. - if hasattr(part, 'text'): - chunks.append(part.text) - samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0)) - return lf.LMSamplingResult(samples) - - def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]: - assert self._api_initialized, 'Vertex AI API is not initialized.' 
- return self._parallel_execute_with_currency_control( - self._sample_single, - prompts, - ) - def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult: - """Samples a single prompt.""" - model = _GOOGLE_GENAI_MODEL_HUB.get(self.model) - input_content = self._content_from_message(prompt) - response = model.generate_content( - input_content, - generation_config=self._generation_config(self.sampling_options), - ) - return self._response_to_result(response) - - -class _LegacyGenerativeModel(pg.Object): - """Base for legacy GenAI generative model.""" - - model: str - - def generate_content( - self, - input_content: list[str | BlobDict], - generation_config: GenerationConfig, - ) -> pg.Dict: - """Generate content.""" - segments = [] - for s in input_content: - if not isinstance(s, str): - raise ValueError(f'Unsupported modality: {s!r}') - segments.append(s) - return self.generate(' '.join(segments), generation_config) - - @abc.abstractmethod - def generate( - self, prompt: str, generation_config: GenerationConfig) -> pg.Dict: - """Generate response based on prompt.""" - - -class _LegacyCompletionModel(_LegacyGenerativeModel): - """Legacy GenAI completion model.""" - - def generate( - self, prompt: str, generation_config: GenerationConfig - ) -> pg.Dict: - assert genai is not None - completion: Completion = genai.generate_text( - model=f'models/{self.model}', - prompt=prompt, - temperature=generation_config.temperature, - top_k=generation_config.top_k, - top_p=generation_config.top_p, - candidate_count=generation_config.candidate_count, - max_output_tokens=generation_config.max_output_tokens, - stop_sequences=generation_config.stop_sequences, - ) - return pg.Dict( - candidates=[ - pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])])) - for c in completion.candidates - ] - ) - - -class _LegacyChatModel(_LegacyGenerativeModel): - """Legacy GenAI chat model.""" - - def generate( - self, prompt: str, generation_config: GenerationConfig - ) -> pg.Dict: - assert genai is not None - response: ChatResponse = genai.chat( - model=f'models/{self.model}', - messages=prompt, - temperature=generation_config.temperature, - top_k=generation_config.top_k, - top_p=generation_config.top_p, - candidate_count=generation_config.candidate_count, - ) - return pg.Dict( - candidates=[ - pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])])) - for c in response.candidates - ] - ) - - -class _ModelHub: - """Google Generative AI model hub.""" - - def __init__(self): - self._model_cache = {} - - def get( - self, model_name: str - ) -> GenerativeModel | _LegacyGenerativeModel: - """Gets a generative model by model id.""" - assert genai is not None - model = self._model_cache.get(model_name, None) - if model is None: - model_info = genai.get_model(f'models/{model_name}') - if 'generateContent' in model_info.supported_generation_methods: - model = genai.GenerativeModel(model_name) - elif 'generateText' in model_info.supported_generation_methods: - model = _LegacyCompletionModel(model_name) - elif 'generateMessage' in model_info.supported_generation_methods: - model = _LegacyChatModel(model_name) - else: - raise ValueError(f'Unsupported model: {model_name!r}') - self._model_cache[model_name] = model - return model - - -_GOOGLE_GENAI_MODEL_HUB = _ModelHub() - - -# -# Public Gemini models. 
-# -class GeminiFlash2_0ThinkingExp(GenAI): # pylint: disable=invalid-name - """Gemini 2.0 Flash Thinking Experimental model.""" +class GeminiFlash2_0ThinkingExp_20241219(GenAI): # pylint: disable=invalid-name + """Gemini Flash 2.0 Thinking model launched on 12/19/2024.""" + api_version = 'v1alpha' model = 'gemini-2.0-flash-thinking-exp-1219' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name - """Gemini Experimental model launched on 12/06/2024.""" + """Gemini Flash 2.0 model launched on 12/11/2024.""" model = 'gemini-2.0-flash-exp' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name """Gemini Experimental model launched on 12/06/2024.""" model = 'gemini-exp-1206' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name """Gemini Experimental model launched on 11/14/2024.""" model = 'gemini-exp-1114' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) class GeminiPro1_5(GenAI): # pylint: disable=invalid-name """Gemini Pro latest model.""" model = 'gemini-1.5-pro-latest' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) -class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name - """Gemini Flash latest model.""" +class GeminiPro1_5_002(GenAI): # pylint: disable=invalid-name + """Gemini Pro 1.5 model stable version 002.""" + + model = 'gemini-1.5-pro-002' - model = 'gemini-1.5-flash-latest' - supported_modalities = ( - vertexai.DOCUMENT_TYPES - + vertexai.IMAGE_TYPES - + vertexai.AUDIO_TYPES - + vertexai.VIDEO_TYPES - ) +class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name + """Gemini Pro 1.5 model stable version 001.""" -class GeminiPro(GenAI): - """Gemini Pro model.""" + model = 'gemini-1.5-pro-001' - model = 'gemini-pro' + +class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name + """Gemini Flash latest model.""" + + model = 'gemini-1.5-flash-latest' -class GeminiProVision(GenAI): - """Gemini Pro vision model.""" +class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name + """Gemini Flash 1.5 model stable version 002.""" - model = 'gemini-pro-vision' - supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES + model = 'gemini-1.5-flash-002' -class Palm2(GenAI): - """PaLM2 model.""" +class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name + """Gemini Flash 1.5 model stable version 001.""" - model = 'text-bison-001' + model = 'gemini-1.5-flash-001' -class Palm2_IT(GenAI): # pylint: disable=invalid-name - """PaLM2 instruction-tuned model.""" +class GeminiPro1(GenAI): # pylint: disable=invalid-name + """Gemini 1.0 Pro model.""" - model = 'chat-bison-001' + model = 'gemini-1.0-pro' diff --git a/langfun/core/llms/google_genai_test.py b/langfun/core/llms/google_genai_test.py index 882f7a8..879e6a6 100644 --- a/langfun/core/llms/google_genai_test.py +++ b/langfun/core/llms/google_genai_test.py @@ -11,223 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-"""Tests for Gemini models.""" +"""Tests for Google GenAI models.""" import os import unittest -from unittest import mock - -from google import generativeai as genai -import langfun.core as lf -from langfun.core import modalities as lf_modalities from langfun.core.llms import google_genai -import pyglove as pg - - -example_image = ( - b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' - b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' - b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' - b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' - b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' - b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' - b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' - b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' -) - - -def mock_get_model(model_name, *args, **kwargs): - del args, kwargs - if 'gemini' in model_name: - method = 'generateContent' - elif 'chat' in model_name: - method = 'generateMessage' - else: - method = 'generateText' - return pg.Dict(supported_generation_methods=[method]) - - -def mock_generate_text(*, model, prompt, **kwargs): - return pg.Dict( - candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')] - ) - - -def mock_chat(*, model, messages, **kwargs): - return pg.Dict( - candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')] - ) - - -def mock_generate_content(content, generation_config, **kwargs): - del kwargs - c = generation_config - return genai.types.GenerateContentResponse( - done=True, - iterator=None, - chunks=[], - result=pg.Dict( - prompt_feedback=pg.Dict(block_reason=None), - candidates=[ - pg.Dict( - content=pg.Dict( - parts=[ - pg.Dict( - text=( - f'This is a response to {content[0]} with ' - f'n={c.candidate_count}, ' - f'temperature={c.temperature}, ' - f'top_p={c.top_p}, ' - f'top_k={c.top_k}, ' - f'max_tokens={c.max_output_tokens}, ' - f'stop={c.stop_sequences}.' - ) - ) - ] - ), - ), - ], - ), - ) class GenAITest(unittest.TestCase): - """Tests for Google GenAI model.""" - - def test_content_from_message_text_only(self): - text = 'This is a beautiful day' - model = google_genai.GeminiPro() - chunks = model._content_from_message(lf.UserMessage(text)) - self.assertEqual(chunks, [text]) - - def test_content_from_message_mm(self): - message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', - image=lf_modalities.Image.from_bytes(example_image), - ) + """Tests for GenAI model.""" - # Non-multimodal model. 
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): - google_genai.GeminiPro()._content_from_message(message) - - model = google_genai.GeminiProVision() - chunks = model._content_from_message(message) - self.maxDiff = None - self.assertEqual( - chunks, - [ - 'This is an', - genai.types.BlobDict(mime_type='image/png', data=example_image), - ', what is it?', - ], - ) - - def test_response_to_result_text_only(self): - response = genai.types.GenerateContentResponse( - done=True, - iterator=None, - chunks=[], - result=pg.Dict( - prompt_feedback=pg.Dict(block_reason=None), - candidates=[ - pg.Dict( - content=pg.Dict( - parts=[pg.Dict(text='This is response 1.')] - ), - ), - pg.Dict( - content=pg.Dict(parts=[pg.Dict(text='This is response 2.')]) - ), - ], - ), - ) - model = google_genai.GeminiProVision() - result = model._response_to_result(response) - self.assertEqual( - result, - lf.LMSamplingResult([ - lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0), - lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0), - ]), - ) - - def test_model_hub(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - - model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro') - self.assertIsNotNone(model) - self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model) - - genai.get_model = orig_get_model - - def test_api_key_check(self): + def test_basics(self): with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'): - _ = google_genai.GeminiPro()._api_initialized + _ = google_genai.GeminiPro1_5().api_endpoint + + self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint) - self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized) os.environ['GOOGLE_API_KEY'] = 'abc' - self.assertTrue(google_genai.GeminiPro()._api_initialized) + lm = google_genai.GeminiPro1_5() + self.assertIsNotNone(lm.api_endpoint) + self.assertTrue(lm.model_id.startswith('GenAI(')) del os.environ['GOOGLE_API_KEY'] - def test_call(self): - with mock.patch( - 'google.generativeai.GenerativeModel.generate_content', - ) as mock_generate: - orig_get_model = genai.get_model - genai.get_model = mock_get_model - mock_generate.side_effect = mock_generate_content - - lm = google_genai.GeminiPro(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text, - ( - 'This is a response to hello with n=1, temperature=2.0, ' - 'top_p=None, top_k=20, max_tokens=1024, stop=None.' 
- ), - ) - genai.get_model = orig_get_model - - def test_call_with_legacy_completion_model(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - orig_generate_text = getattr(genai, 'generate_text', None) - if orig_generate_text is not None: - genai.generate_text = mock_generate_text - - lm = google_genai.Palm2(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20).text, - ( - "hello to models/text-bison-001 with {'temperature': 2.0, " - "'top_k': 20, 'top_p': None, 'candidate_count': 1, " - "'max_output_tokens': None, 'stop_sequences': None}" - ), - ) - genai.generate_text = orig_generate_text - genai.get_model = orig_get_model - - def test_call_with_legacy_chat_model(self): - orig_get_model = genai.get_model - genai.get_model = mock_get_model - orig_chat = getattr(genai, 'chat', None) - if orig_chat is not None: - genai.chat = mock_chat - - lm = google_genai.Palm2_IT(api_key='test_key') - self.maxDiff = None - self.assertEqual( - lm('hello', temperature=2.0, top_k=20).text, - ( - "hello to models/chat-bison-001 with {'temperature': 2.0, " - "'top_k': 20, 'top_p': None, 'candidate_count': 1}" - ), - ) - genai.chat = orig_chat - genai.get_model = orig_get_model - if __name__ == '__main__': unittest.main() diff --git a/langfun/core/llms/vertexai.py b/langfun/core/llms/vertexai.py index 24aac4c..fd4c698 100644 --- a/langfun/core/llms/vertexai.py +++ b/langfun/core/llms/vertexai.py @@ -13,14 +13,12 @@ # limitations under the License. """Vertex AI generative models.""" -import base64 import functools import os from typing import Annotated, Any import langfun.core as lf -from langfun.core import modalities as lf_modalities -from langfun.core.llms import rest +from langfun.core.llms import gemini import pyglove as pg try: @@ -38,114 +36,11 @@ Credentials = Any -# https://cloud.google.com/vertex-ai/generative-ai/pricing -# describes that the average number of characters per token is about 4. -AVGERAGE_CHARS_PER_TOKEN = 4 - - -# Price in US dollars, -# from https://cloud.google.com/vertex-ai/generative-ai/pricing -# as of 2024-10-10. 
-SUPPORTED_MODELS_AND_SETTINGS = { - 'gemini-1.5-pro-001': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-pro-002': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash-002': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-flash-001': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-pro': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash': pg.Dict( - rpm=500, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.5-pro-preview-0514': pg.Dict( - rpm=50, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-pro-preview-0409': pg.Dict( - rpm=50, - cost_per_1k_input_chars=0.0003125, - cost_per_1k_output_chars=0.00125, - ), - 'gemini-1.5-flash-preview-0514': pg.Dict( - rpm=200, - cost_per_1k_input_chars=0.00001875, - cost_per_1k_output_chars=0.000075, - ), - 'gemini-1.0-pro': pg.Dict( - rpm=300, - cost_per_1k_input_chars=0.000125, - cost_per_1k_output_chars=0.000375, - ), - 'gemini-1.0-pro-vision': pg.Dict( - rpm=100, - cost_per_1k_input_chars=0.000125, - cost_per_1k_output_chars=0.000375, - ), - # TODO(sharatsharat): Update costs when published - 'gemini-exp-1206': pg.Dict( - rpm=20, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(sharatsharat): Update costs when published - 'gemini-2.0-flash-exp': pg.Dict( - rpm=10, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(yifenglu): Update costs when published - 'gemini-2.0-flash-thinking-exp-1219': pg.Dict( - rpm=10, - cost_per_1k_input_chars=0.000, - cost_per_1k_output_chars=0.000, - ), - # TODO(chengrun): Set a more appropriate rpm for endpoint. - 'vertexai-endpoint': pg.Dict( - rpm=20, - cost_per_1k_input_chars=0.0000125, - cost_per_1k_output_chars=0.0000375, - ), -} - - @lf.use_init_args(['model']) @pg.members([('api_endpoint', pg.typing.Str().freeze(''))]) -class VertexAI(rest.REST): +class VertexAI(gemini.Gemini): """Language model served on VertexAI with REST API.""" - model: pg.typing.Annotated[ - pg.typing.Enum( - pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys()) - ), - ( - 'Vertex AI model name with REST API support. See ' - 'https://cloud.google.com/vertex-ai/generative-ai/docs/' - 'model-reference/inference#supported-models' - ' for details.' 
- ), - ] - project: Annotated[ str | None, ( @@ -170,11 +65,6 @@ class VertexAI(rest.REST): ), ] = None - supported_modalities: Annotated[ - list[str], - 'A list of MIME types for supported modalities' - ] = [] - def _on_bound(self): super()._on_bound() if google_auth is None: @@ -209,31 +99,9 @@ def _initialize(self): self._credentials = credentials @property - def max_concurrency(self) -> int: - """Returns the maximum number of concurrent requests.""" - return self.rate_to_max_concurrency( - requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm, - tokens_per_min=0, - ) - - def estimate_cost( - self, - num_input_tokens: int, - num_output_tokens: int - ) -> float | None: - """Estimate the cost based on usage.""" - cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_input_chars', None - ) - cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get( - 'cost_per_1k_output_chars', None - ) - if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None: - return None - return ( - cost_per_1k_input_chars * num_input_tokens - + cost_per_1k_output_chars * num_output_tokens - ) * AVGERAGE_CHARS_PER_TOKEN / 1000 + def model_id(self) -> str: + """Returns a string to identify the model.""" + return f'VertexAI({self.model})' @functools.cached_property def _session(self): @@ -244,12 +112,6 @@ def _session(self): s.headers.update(self.headers or {}) return s - @property - def headers(self): - return { - 'Content-Type': 'application/json; charset=utf-8', - } - @property def api_endpoint(self) -> str: return ( @@ -258,263 +120,69 @@ def api_endpoint(self) -> str: f'models/{self.model}:generateContent' ) - def request( - self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions - ) -> dict[str, Any]: - request = dict( - generationConfig=self._generation_config(prompt, sampling_options) - ) - request['contents'] = [self._content_from_message(prompt)] - return request - - def _generation_config( - self, prompt: lf.Message, options: lf.LMSamplingOptions - ) -> dict[str, Any]: - """Returns a dict as generation config for prompt and LMSamplingOptions.""" - config = dict( - temperature=options.temperature, - maxOutputTokens=options.max_tokens, - candidateCount=options.n, - topK=options.top_k, - topP=options.top_p, - stopSequences=options.stop, - seed=options.random_seed, - responseLogprobs=options.logprobs, - logprobs=options.top_logprobs, - ) - if json_schema := prompt.metadata.get('json_schema'): - if not isinstance(json_schema, dict): - raise ValueError( - f'`json_schema` must be a dict, got {json_schema!r}.' 
- ) - json_schema = pg.to_json(json_schema) - config['responseSchema'] = json_schema - config['responseMimeType'] = 'application/json' - prompt.metadata.formatted_text = ( - prompt.text - + '\n\n [RESPONSE FORMAT (not part of prompt)]\n' - + pg.to_json_str(json_schema, json_indent=2) - ) - return config - - def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]: - """Gets generation content from langfun message.""" - parts = [] - for lf_chunk in prompt.chunk(): - if isinstance(lf_chunk, str): - parts.append({'text': lf_chunk}) - elif isinstance(lf_chunk, lf_modalities.Mime): - try: - modalities = lf_chunk.make_compatible( - self.supported_modalities + ['text/plain'] - ) - if isinstance(modalities, lf_modalities.Mime): - modalities = [modalities] - for modality in modalities: - if modality.is_text: - parts.append({'text': modality.to_text()}) - else: - parts.append({ - 'inlineData': { - 'data': base64.b64encode(modality.to_bytes()).decode(), - 'mimeType': modality.mime_type, - } - }) - except lf.ModalityError as e: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e - else: - raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') - return dict(role='user', parts=parts) - - def result(self, json: dict[str, Any]) -> lf.LMSamplingResult: - messages = [ - self._message_from_content_parts(candidate['content']['parts']) - for candidate in json['candidates'] - ] - usage = json['usageMetadata'] - input_tokens = usage['promptTokenCount'] - output_tokens = usage['candidatesTokenCount'] - return lf.LMSamplingResult( - [lf.LMSample(message) for message in messages], - usage=lf.LMSamplingUsage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - estimated_cost=self.estimate_cost( - num_input_tokens=input_tokens, - num_output_tokens=output_tokens, - ), - ), - ) +class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024.""" - def _message_from_content_parts( - self, parts: list[dict[str, Any]] - ) -> lf.Message: - """Converts Vertex AI's content parts protocol to message.""" - chunks = [] - for part in parts: - if text_part := part.get('text'): - chunks.append(text_part) - else: - raise ValueError(f'Unsupported part: {part}') - return lf.AIMessage.from_chunks(chunks) - - -IMAGE_TYPES = [ - 'image/png', - 'image/jpeg', - 'image/webp', - 'image/heic', - 'image/heif', -] - -AUDIO_TYPES = [ - 'audio/aac', - 'audio/flac', - 'audio/mp3', - 'audio/m4a', - 'audio/mpeg', - 'audio/mpga', - 'audio/mp4', - 'audio/opus', - 'audio/pcm', - 'audio/wav', - 'audio/webm', -] - -VIDEO_TYPES = [ - 'video/mov', - 'video/mpeg', - 'video/mpegps', - 'video/mpg', - 'video/mp4', - 'video/webm', - 'video/wmv', - 'video/x-flv', - 'video/3gpp', - 'video/quicktime', -] - -DOCUMENT_TYPES = [ - 'application/pdf', - 'text/plain', - 'text/csv', - 'text/html', - 'text/xml', - 'text/x-script.python', - 'application/json', -] - - -class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 2.0 model.""" - - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES - ) - - -class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name + api_version = 'v1alpha' + model = 'gemini-2.0-flash-thinking-exp-1219' + + +class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 2.0 
Flash model.""" model = 'gemini-2.0-flash-exp' -class VertexAIGeminiFlash2_0ThinkingExp(VertexAIGemini2_0): # pylint: disable=invalid-name - """Vertex AI Gemini 2.0 Flash model.""" +class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Experimental model launched on 12/06/2024.""" - model = 'gemini-2.0-flash-thinking-exp-1219' + model = 'gemini-exp-1206' -class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 model.""" +class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name + """Vertex AI Gemini Experimental model launched on 11/14/2024.""" - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES - ) + model = 'gemini-exp-1114' -class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" - model = 'gemini-1.5-pro' + model = 'gemini-1.5-pro-latest' -class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-002' -class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Pro model.""" model = 'gemini-1.5-pro-001' -class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Pro preview model.""" - - model = 'gemini-1.5-pro-preview-0514' - - -class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Pro preview model.""" - - model = 'gemini-1.5-pro-preview-0409' - - -class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash' -class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash-002' -class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name +class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.5 Flash model.""" model = 'gemini-1.5-flash-001' -class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name - """Vertex AI Gemini 1.5 Flash preview model.""" - - model = 'gemini-1.5-flash-preview-0514' - - class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name """Vertex AI Gemini 1.0 Pro model.""" model = 'gemini-1.0-pro' - - -class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name - """Vertex AI Gemini 1.0 Pro Vision model.""" - - model = 'gemini-1.0-pro-vision' - supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation - IMAGE_TYPES + VIDEO_TYPES - ) - - -class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name - """Vertex AI Endpoint model.""" - - model = 'vertexai-endpoint' - - endpoint: Annotated[str, 'Vertex AI Endpoint ID.'] - - @property - def api_endpoint(self) -> str: - return ( - f'https://{self.location}-aiplatform.googleapis.com/v1/projects/' - f'{self.project}/locations/{self.location}/' - 
f'endpoints/{self.endpoint}:generateContent' - ) diff --git a/langfun/core/llms/vertexai_test.py b/langfun/core/llms/vertexai_test.py index 3e9d1ca..3e7989e 100644 --- a/langfun/core/llms/vertexai_test.py +++ b/langfun/core/llms/vertexai_test.py @@ -11,105 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Gemini models.""" +"""Tests for VertexAI models.""" -import base64 import os -from typing import Any import unittest from unittest import mock -import langfun.core as lf -from langfun.core import modalities as lf_modalities from langfun.core.llms import vertexai -import pyglove as pg -import requests - - -example_image = ( - b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04' - b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00' - b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS' - b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx' - b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10' - b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3' - b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9' - b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82' -) - - -def mock_requests_post(url: str, json: dict[str, Any], **kwargs): - del url, kwargs - c = pg.Dict(json['generationConfig']) - content = json['contents'][0]['parts'][0]['text'] - response = requests.Response() - response.status_code = 200 - response._content = pg.to_json_str({ - 'candidates': [ - { - 'content': { - 'role': 'model', - 'parts': [ - { - 'text': ( - f'This is a response to {content} with ' - f'temperature={c.temperature}, ' - f'top_p={c.topP}, ' - f'top_k={c.topK}, ' - f'max_tokens={c.maxOutputTokens}, ' - f'stop={"".join(c.stopSequences)}.' - ) - }, - ], - }, - }, - ], - 'usageMetadata': { - 'promptTokenCount': 3, - 'candidatesTokenCount': 4, - } - }).encode() - return response class VertexAITest(unittest.TestCase): """Tests for Vertex model with REST API.""" - def test_content_from_message_text_only(self): - text = 'This is a beautiful day' - model = vertexai.VertexAIGeminiPro1_5_002() - chunks = model._content_from_message(lf.UserMessage(text)) - self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]}) - - def test_content_from_message_mm(self): - image = lf_modalities.Image.from_bytes(example_image) - message = lf.UserMessage( - 'This is an <<[[image]]>>, what is it?', image=image - ) - - # Non-multimodal model. 
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'): - vertexai.VertexAIGeminiPro1()._content_from_message(message) - - model = vertexai.VertexAIGeminiPro1Vision() - content = model._content_from_message(message) - self.assertEqual( - content, - { - 'role': 'user', - 'parts': [ - {'text': 'This is an'}, - { - 'inlineData': { - 'data': base64.b64encode(example_image).decode(), - 'mimeType': 'image/png', - } - }, - {'text': ', what is it?'}, - ], - }, - ) - @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) def test_project_and_location_check(self): with self.assertRaisesRegex(ValueError, 'Please specify `project`'): @@ -126,87 +39,14 @@ def test_project_and_location_check(self): os.environ['VERTEXAI_PROJECT'] = 'abc' os.environ['VERTEXAI_LOCATION'] = 'us-central1' - self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized) + model = vertexai.VertexAIGeminiPro1() + self.assertTrue(model.model_id.startswith('VertexAI(')) + self.assertIsNotNone(model.api_endpoint) + self.assertTrue(model._api_initialized) + self.assertIsNotNone(model._session) del os.environ['VERTEXAI_PROJECT'] del os.environ['VERTEXAI_LOCATION'] - def test_generation_config(self): - model = vertexai.VertexAIGeminiPro1() - json_schema = { - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - }, - 'required': ['name'], - 'title': 'Person', - } - actual = model._generation_config( - lf.UserMessage('hi', json_schema=json_schema), - lf.LMSamplingOptions( - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop=['\n'], - ), - ) - self.assertEqual( - actual, - dict( - candidateCount=1, - temperature=2.0, - topP=1.0, - topK=20, - maxOutputTokens=1024, - stopSequences=['\n'], - responseLogprobs=False, - logprobs=None, - seed=None, - responseMimeType='application/json', - responseSchema={ - 'type': 'object', - 'properties': { - 'name': {'type': 'string'} - }, - 'required': ['name'], - 'title': 'Person', - } - ), - ) - with self.assertRaisesRegex( - ValueError, '`json_schema` must be a dict, got' - ): - model._generation_config( - lf.UserMessage('hi', json_schema='not a dict'), - lf.LMSamplingOptions(), - ) - - @mock.patch.object(vertexai.VertexAI, 'credentials', new=True) - def test_call_model(self): - with mock.patch('requests.Session.post') as mock_generate: - mock_generate.side_effect = mock_requests_post - - lm = vertexai.VertexAIGeminiPro1_5_002( - project='abc', location='us-central1' - ) - r = lm( - 'hello', - temperature=2.0, - top_p=1.0, - top_k=20, - max_tokens=1024, - stop='\n', - ) - self.assertEqual( - r.text, - ( - 'This is a response to hello with temperature=2.0, ' - 'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.' - ), - ) - self.assertEqual(r.metadata.usage.prompt_tokens, 3) - self.assertEqual(r.metadata.usage.completion_tokens, 4) - if __name__ == '__main__': unittest.main() diff --git a/requirements.txt b/requirements.txt index 440c8c5..babae31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,8 @@ requests>=2.31.0 termcolor==1.1.0 tqdm>=4.64.1 -# extras:llm-google-genai -google-generativeai>=0.3.2 +# extras:vertexai +google-auth>=2.16.0 # extras:mime-auto python-magic>=0.4.27
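A minimal usage sketch of the refactored entry points, assuming valid credentials (the key, project, and location values below are placeholders): after this change, `GenAI` and `VertexAI` differ only in authentication and endpoint construction, while request building, rate control, and cost estimation come from the shared `gemini.Gemini` base.

```python
from langfun.core.llms import google_genai
from langfun.core.llms import vertexai

# GenAI route: authenticates with an API key (passed explicitly or via
# the GOOGLE_API_KEY environment variable), which is appended to the
# generativelanguage.googleapis.com endpoint.
genai_lm = google_genai.GeminiPro1_5(api_key='<api-key>')
print(genai_lm.model_id)   # 'GenAI(gemini-1.5-pro-latest)'

# Vertex AI route: authenticates with google-auth credentials plus a
# project/location pair (or VERTEXAI_PROJECT / VERTEXAI_LOCATION) against
# the regional aiplatform.googleapis.com endpoint.
vertex_lm = vertexai.VertexAIGeminiPro1_5(
    project='<gcp-project>', location='us-central1'
)
print(vertex_lm.model_id)  # 'VertexAI(gemini-1.5-pro-latest)'

# Both accept the same sampling options; usage and estimated cost are
# derived from the shared SUPPORTED_MODELS_AND_SETTINGS table.
response = genai_lm('hello', temperature=0.7, max_tokens=256)
print(response.text, response.metadata.usage)
```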
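For reference when reviewing `request()` and `result()` in `gemini.py`, a sketch of the `generateContent` payload both backends now emit for a plain-text prompt; the field names come from `_generation_config` and `_content_from_message`, while the values here are illustrative only.

```python
# Illustrative payload shape assembled by gemini.Gemini.request() from a
# prompt and LMSamplingOptions; multimodal chunks would add 'inlineData'
# parts alongside the 'text' part.
payload = {
    'generationConfig': {
        'temperature': 0.7,
        'maxOutputTokens': 256,
        'candidateCount': 1,
        'topK': 40,
        'topP': 0.95,
        'stopSequences': ['\n'],
        'seed': None,
        'responseLogprobs': False,
        'logprobs': None,
    },
    'contents': [{'role': 'user', 'parts': [{'text': 'hello'}]}],
}
```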