Commit
Merge pull request #42 from D-X-Y:x
PiperOrigin-RevId: 572718586
langfun authors committed Oct 11, 2023
2 parents fc65ac3 + aab4e9d commit 5a5c08c
Showing 6 changed files with 152 additions and 11 deletions.
3 changes: 3 additions & 0 deletions langfun/core/llms/__init__.py
@@ -46,6 +46,9 @@
from langfun.core.llms.openai import Gpt3Babbage
from langfun.core.llms.openai import Gpt3Ada

# LLaMA C++ models.
from langfun.core.llms.llama_cpp import LlamaCppRemote

# Placeholder for Google-internal imports.

# Include cache as sub-module.
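
With the export above, the new class becomes reachable from the package root. A one-line usage sketch (not part of the commit; assumes this revision of langfun is installed):

from langfun.core.llms import LlamaCppRemote  # re-exported by this change
# from langfun.core.llms.llama_cpp import LlamaCppRemote  # direct module path also works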
76 changes: 76 additions & 0 deletions langfun/core/llms/llama_cpp.py
@@ -0,0 +1,76 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language models from llama.cpp."""

from typing import Annotated

import langfun.core as lf
import requests


@lf.use_init_args(["url"])
class LlamaCppRemote(lf.LanguageModel):
"""The remote LLaMA C++ model.
The Remote LLaMA C++ models can be launched via
https://github.com/ggerganov/llama.cpp/tree/master/examples/server
"""

  url: Annotated[
      str,
      "The URL of the LLaMA C++ server, e.g. http://127.0.0.1:8080.",
  ] = ""

  name: Annotated[
      str,
      "The abbreviation for the LLaMA CPP-based model name.",
  ] = ""

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return f"LLaMAC++({self.name})"

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    def _complete_fn(cur_prompts):
      results = []
      for prompt in cur_prompts:
        result = lf.LMSamplingResult()
        for _ in range(self.sampling_options.n or 1):
          data = {
              "prompt": prompt.text,
              "n_predict": self.sampling_options.max_tokens,
              "temperature": self.sampling_options.temperature,
              "top_k": self.sampling_options.top_k or 50,
              "top_p": self.sampling_options.top_p or 0.95,
          }
          response = requests.post(
              f"{self.url}/completion",
              json=data,
              headers={"Content-Type": "application/json"},
              timeout=self.timeout,
          )
          decoded_response = response.json()
          response = decoded_response["content"]
          result.samples.append(lf.LMSample(response, score=0.0))
        results.append(result)
      return results

    return lf.with_retry(
        _complete_fn,
        retry_on_errors=(),
        max_attempts=self.max_attempts,
        retry_interval=(1, 60),
        exponential_backoff=True,
    )(prompts)
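
For context, a minimal usage sketch of the new class (not part of this commit; the URL and model name are placeholders, and a llama.cpp server from the examples/server link above must already be listening at that address):

# Hedged usage sketch: hypothetical endpoint and model name.
import langfun.core as lf
from langfun.core.llms import LlamaCppRemote

lm = LlamaCppRemote(url="http://127.0.0.1:8080", name="llama-2-7b")
response = lm("Write a haiku about mountains.",
              sampling_options=lf.LMSamplingOptions(n=1))
print(response.text)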
56 changes: 56 additions & 0 deletions langfun/core/llms/llama_cpp_test.py
@@ -0,0 +1,56 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for llama cpp models."""

import typing
import unittest
from unittest import mock

import langfun.core as lf
from langfun.core.llms import llama_cpp


def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
  del kwargs

  class TEMP:

    def json(self):
      return {"content": json["prompt"] + "\n" + url}

  return TEMP()


class LlamaCppRemoteTest(unittest.TestCase):
  """Tests for the LlamaCppRemote model."""

  def test_call_completion(self):
    with mock.patch("requests.post") as mock_request:
      mock_request.side_effect = mock_requests_post
      lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
      response = lm("hello", sampling_options=lf.LMSamplingOptions(n=1))
      self.assertEqual(
          response.text,
          "hello\nhttp://127.0.0.1:8080/completion",
      )

  def test_name(self):
    lm = llama_cpp.LlamaCppRemote()
    self.assertEqual(lm.model_id, "LLaMAC++()")
    lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
    self.assertEqual(lm.model_id, "LLaMAC++(x)")


if __name__ == "__main__":
  unittest.main()
25 changes: 15 additions & 10 deletions langfun/core/llms/openai.py
@@ -15,9 +15,12 @@

import collections
import os
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, cast

import langfun.core as lf
import openai
from openai import error as openai_error
from openai import openai_object
import pyglove as pg


@@ -128,7 +131,8 @@ def _get_request_args(
def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
if self.is_chat_model:
return self._chat_complete_batch(prompts)
return self._complete_batch(prompts)
else:
return self._complete_batch(prompts)

def _complete_batch(
self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
@@ -138,7 +142,7 @@ def _open_ai_completion(prompts):
prompt=[p.text for p in prompts],
**self._get_request_args(self.sampling_options),
)

response = cast(openai_object.OpenAIObject, response)
# Parse response.
samples_by_index = collections.defaultdict(list)
for choice in response.choices:
@@ -161,23 +165,24 @@ def _open_ai_completion(prompts):
return lf.with_retry(
_open_ai_completion,
retry_on_errors=(
openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai_error.ServiceUnavailableError,
openai_error.RateLimitError,
),
max_attempts=self.max_attempts,
retry_interval=(1, 60),
exponential_backoff=True,
)(prompts)

def _chat_complete_batch(
self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:

self, prompts: list[lf.Message]
) -> list[LMSamplingResult]:
def _open_ai_chat_completion(prompt):
response = openai.ChatCompletion.create(
# TODO(daiyip): support conversation history.
# TODO(daiyip): support conversation history and system prompt.
messages=[{'role': 'user', 'content': prompt.text}],
**self._get_request_args(self.sampling_options),
)
response = cast(openai_object.OpenAIObject, response)
return LMSamplingResult(
[
lf.LMSample(choice.message.content, score=0.0)
@@ -195,8 +200,8 @@ def _open_ai_chat_completion(prompt):
prompts,
max_workers=8,
retry_on_errors=(
openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai_error.ServiceUnavailableError,
openai_error.RateLimitError,
),
max_attempts=self.max_attempts,
retry_interval=(1, 60),
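
A note on the cast(...) lines introduced above: typing.cast only narrows the declared type for static checkers; it performs no conversion at runtime, so the response object is unchanged. A tiny illustration (not from the commit):

# typing.cast is a no-op at runtime; only the static type changes.
from typing import Any, cast

response: Any = {"choices": []}
typed = cast(dict, response)
assert typed is response  # same object, no conversion performed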
2 changes: 1 addition & 1 deletion langfun/core/message.py
@@ -102,7 +102,7 @@ def __init__(
sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
source: 'Message' = None,
source: Optional['Message'] = None,
# The rest are `pg.Object.__init__` arguments.
allow_partial: bool = False,
sealed: bool = False,
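
The message.py change above fixes an annotation that defaulted to None without being optional. A small, self-contained illustration of the same pattern (hypothetical class, not from langfun):

# Hypothetical example: "parent: 'Node' = None" is rejected by strict type
# checkers because None is not a Node; Optional[...] makes the default valid.
from typing import Optional

class Node:
  def __init__(self, parent: Optional['Node'] = None):
    self.parent = parent

root = Node()        # parent defaults to None
child = Node(root)   # or an explicit Node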
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
jinja2>=3.1.2
openai==0.27.2
pyglove>=0.4.4.dev20231009
requests>=2.31.0
termcolor==1.1.0
tqdm>=4.64.1
