diff --git a/langfun/core/llms/__init__.py b/langfun/core/llms/__init__.py
index df476d2..ace87da 100644
--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -46,6 +46,9 @@
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada
 
+# LLaMA C++ models.
+from langfun.core.llms.llama_cpp import LlamaCppRemote
+
 # Placeholder for Google-internal imports.
 
 # Include cache as sub-module.
diff --git a/langfun/core/llms/llama_cpp.py b/langfun/core/llms/llama_cpp.py
new file mode 100644
index 0000000..276a20e
--- /dev/null
+++ b/langfun/core/llms/llama_cpp.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from llama.cpp."""
+
+from typing import Annotated
+
+import langfun.core as lf
+import requests
+
+
+@lf.use_init_args(["url"])
+class LlamaCppRemote(lf.LanguageModel):
+  """The remote LLaMA C++ model.
+
+  The remote LLaMA C++ server can be launched via
+  https://github.com/ggerganov/llama.cpp/tree/master/examples/server
+  """
+
+  url: Annotated[
+      str,
+      "The URL of the remote LLaMA C++ server.",
+  ] = ""
+
+  name: Annotated[
+      str,
+      "The abbreviation for the LLaMA CPP-based model name.",
+  ] = ""
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f"LLaMAC++({self.name})"
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    def _complete_fn(cur_prompts):
+      results = []
+      for prompt in cur_prompts:
+        result = lf.LMSamplingResult()
+        for _ in range(self.sampling_options.n or 1):
+          data = {
+              "prompt": prompt.text,
+              "n_predict": self.sampling_options.max_tokens,
+              "temperature": self.sampling_options.temperature,
+              "top_k": self.sampling_options.top_k or 50,
+              "top_p": self.sampling_options.top_p or 0.95,
+          }
+          response = requests.post(
+              f"{self.url}/completion",
+              json=data,
+              headers={"Content-Type": "application/json"},
+              timeout=self.timeout,
+          )
+          decoded_response = response.json()
+          response = decoded_response["content"]
+          result.samples.append(lf.LMSample(response, score=0.0))
+        results.append(result)
+      return results
+
+    return lf.with_retry(
+        _complete_fn,
+        retry_on_errors=(),
+        max_attempts=self.max_attempts,
+        retry_interval=(1, 60),
+        exponential_backoff=True,
+    )(prompts)
diff --git a/langfun/core/llms/llama_cpp_test.py b/langfun/core/llms/llama_cpp_test.py
new file mode 100644
index 0000000..fbcf9e6
--- /dev/null
+++ b/langfun/core/llms/llama_cpp_test.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for llama.cpp models."""
+
+import typing
+import unittest
+from unittest import mock
+
+import langfun.core as lf
+from langfun.core.llms import llama_cpp
+
+
+def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
+  del kwargs
+
+  class TEMP:
+
+    def json(self):
+      return {"content": json["prompt"] + "\n" + url}
+
+  return TEMP()
+
+
+class LlamaCppRemoteTest(unittest.TestCase):
+  """Tests for the LlamaCppRemote model."""
+
+  def test_call_completion(self):
+    with mock.patch("requests.post") as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
+      response = lm("hello", sampling_options=lf.LMSamplingOptions(n=1))
+      self.assertEqual(
+          response.text,
+          "hello\nhttp://127.0.0.1:8080/completion",
+      )
+
+  def test_name(self):
+    lm = llama_cpp.LlamaCppRemote()
+    self.assertEqual(lm.model_id, "LLaMAC++()")
+    lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
+    self.assertEqual(lm.model_id, "LLaMAC++(x)")
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/langfun/core/llms/openai.py b/langfun/core/llms/openai.py
index d9b04c0..cfe47b0 100644
--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -15,9 +15,12 @@
 
 import collections
 import os
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, cast
+
 import langfun.core as lf
 import openai
+from openai import error as openai_error
+from openai import openai_object
 import pyglove as pg
 
 
@@ -128,7 +131,8 @@ def _get_request_args(
   def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
     if self.is_chat_model:
       return self._chat_complete_batch(prompts)
-    return self._complete_batch(prompts)
+    else:
+      return self._complete_batch(prompts)
 
   def _complete_batch(
       self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
@@ -138,7 +142,7 @@ def _open_ai_completion(prompts):
           prompt=[p.text for p in prompts],
           **self._get_request_args(self.sampling_options),
       )
-
+      response = cast(openai_object.OpenAIObject, response)
       # Parse response.
       samples_by_index = collections.defaultdict(list)
       for choice in response.choices:
@@ -161,8 +165,8 @@ def _open_ai_completion(prompts):
     return lf.with_retry(
         _open_ai_completion,
         retry_on_errors=(
-            openai.error.ServiceUnavailableError,
-            openai.error.RateLimitError,
+            openai_error.ServiceUnavailableError,
+            openai_error.RateLimitError,
         ),
         max_attempts=self.max_attempts,
         retry_interval=(1, 60),
@@ -170,14 +174,15 @@ def _open_ai_completion(prompts):
     )(prompts)
 
   def _chat_complete_batch(
-      self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-
+      self, prompts: list[lf.Message]
+  ) -> list[LMSamplingResult]:
     def _open_ai_chat_completion(prompt):
       response = openai.ChatCompletion.create(
-          # TODO(daiyip): support conversation history.
+          # TODO(daiyip): support conversation history and system prompt.
           messages=[{'role': 'user', 'content': prompt.text}],
          **self._get_request_args(self.sampling_options),
       )
+      response = cast(openai_object.OpenAIObject, response)
       return LMSamplingResult(
           [
               lf.LMSample(choice.message.content, score=0.0)
@@ -195,8 +200,8 @@ def _open_ai_chat_completion(prompt):
         prompts,
         max_workers=8,
         retry_on_errors=(
-            openai.error.ServiceUnavailableError,
-            openai.error.RateLimitError,
+            openai_error.ServiceUnavailableError,
+            openai_error.RateLimitError,
         ),
         max_attempts=self.max_attempts,
         retry_interval=(1, 60),
diff --git a/langfun/core/message.py b/langfun/core/message.py
index e02f678..a98a501 100644
--- a/langfun/core/message.py
+++ b/langfun/core/message.py
@@ -102,7 +102,7 @@ def __init__(
       sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
       metadata: dict[str, Any] | None = None,
       tags: list[str] | None = None,
-      source: 'Message' = None,
+      source: Optional['Message'] = None,
       # The rest are `pg.Object.__init__` arguments.
       allow_partial: bool = False,
       sealed: bool = False,
diff --git a/requirements.txt b/requirements.txt
index 9340a69..a17206e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 jinja2>=3.1.2
 openai==0.27.2
 pyglove>=0.4.4.dev20231009
+requests>=2.31.0
 termcolor==1.1.0
 tqdm>=4.64.1
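
Below is a minimal usage sketch of the new LlamaCppRemote model introduced by this change, assuming a llama.cpp server (launched from the examples/server directory linked in the class docstring) is already listening at the given address; the URL, the "llama2-7b" abbreviation, the prompt, and the sampling settings are illustrative placeholders, not part of the diff above.

import langfun.core as lf
from langfun.core.llms import LlamaCppRemote

# Point the model at a running llama.cpp server; this address is hypothetical.
lm = LlamaCppRemote(url="http://127.0.0.1:8080", name="llama2-7b")

# Calling the model returns a message whose `.text` carries the completion,
# mirroring the expectation in llama_cpp_test.py.
response = lm(
    "Write a haiku about mountains.",
    sampling_options=lf.LMSamplingOptions(n=1, temperature=0.7),
)
print(response.text)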