Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLaMA C++ models and Fix PyType warning/errors raised in VS Code #31

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
ignore = E111 # ignore the error of "indentation is not a multiple of 4"
3 changes: 3 additions & 0 deletions langfun/core/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
from langfun.core.llms.openai import Gpt3Babbage
from langfun.core.llms.openai import Gpt3Ada

# LLaMA C++ models.
from langfun.core.llms.llama_cpp import LlamaCppRemote

# Placeholder for Google-internal imports.

# Include cache as sub-module.
Expand Down
80 changes: 80 additions & 0 deletions langfun/core/llms/llama_cpp.py
daiyip marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language models from llama.cpp."""

from typing import Annotated

import requests

import langfun.core as lf

@lf.use_init_args(['url'])
class LlamaCppRemote(lf.LanguageModel):
"""The remote LLaMA C++ model.

The Remote LLaMA C++ models can be launched via
https://github.com/ggerganov/llama.cpp/tree/master/examples/server
"""

url: Annotated[
str | None,
'The name of the model to use.',
] = None

name: Annotated[
daiyip marked this conversation as resolved.
Show resolved Hide resolved
str,
'The abbreviation for the LLaMA CPP-based model name.',
] = ''

def _on_bound(self):
super()._on_bound()
if not self.url:
raise ValueError('Please specify `url`')

@property
def model_id(self) -> str:
"""Returns a string to identify the model."""
return f'LLaMAC++({self.name})'

def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
def _complete_fn(cur_prompts):
results = []
for prompt in cur_prompts:
result = lf.LMSamplingResult()
for _ in range(self.sampling_options.n or 1):
data = {
"prompt": prompt.text,
"n_predict": self.sampling_options.max_tokens,
"temperature": self.sampling_options.temperature,
"top_k": self.sampling_options.top_k or 50,
"top_p": self.sampling_options.top_p or 0.95,
daiyip marked this conversation as resolved.
Show resolved Hide resolved
}
response = requests.post(
f"{self.url}/completion",
json=data,
headers={"Content-Type": "application/json"},
timeout=self.timeout,
)
decoded_response = response.json()
response = decoded_response["content"]
result.samples.append(lf.LMSample(response, score=0.))
results.append(result)
return results
return lf.with_retry(
_complete_fn,
retry_on_errors=(),
max_attempts=self.max_attempts,
retry_interval=(1, 60),
exponential_backoff=True,
)(prompts)
45 changes: 45 additions & 0 deletions langfun/core/llms/llama_cpp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for llama cpp models."""

import unittest
from unittest import mock
import typing
import langfun.core as lf
from langfun.core.llms import llama_cpp


def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
del kwargs
class TEMP:
def json(self):
return {"content": json["prompt"] + "\n" + url}
return TEMP()


class LlamaCppRemoteTest(unittest.TestCase):
"""Tests for the LlamaCppRemote model."""

def test_call_completion(self):
with mock.patch('requests.post') as mock_request:
mock_request.side_effect = mock_requests_post
lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
response = lm('hello', sampling_options=lf.LMSamplingOptions(n=1))
self.assertEqual(
response.text,
'hello\nhttp://127.0.0.1:8080/completion',
)

if __name__ == '__main__':
unittest.main()
26 changes: 16 additions & 10 deletions langfun/core/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

import collections
import os
from typing import Annotated, Any, Literal
import langfun.core as lf
from typing import Annotated, Any, Literal, cast

import openai
import pyglove as pg
from openai import openai_object
from openai import error as openai_error

import langfun.core as lf


class Usage(pg.Object):
Expand Down Expand Up @@ -128,7 +132,8 @@ def _get_request_args(
def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
if self.is_chat_model:
return self._chat_complete_batch(prompts)
return self._complete_batch(prompts)
else:
return self._complete_batch(prompts)

def _complete_batch(
self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
Expand All @@ -138,7 +143,7 @@ def _open_ai_completion(prompts):
prompt=[p.text for p in prompts],
**self._get_request_args(self.sampling_options),
)

response = cast(openai_object.OpenAIObject, response)
# Parse response.
samples_by_index = collections.defaultdict(list)
for choice in response.choices:
Expand All @@ -161,23 +166,24 @@ def _open_ai_completion(prompts):
return lf.with_retry(
_open_ai_completion,
retry_on_errors=(
openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai_error.ServiceUnavailableError,
openai_error.RateLimitError,
),
max_attempts=self.max_attempts,
retry_interval=(1, 60),
exponential_backoff=True,
)(prompts)

def _chat_complete_batch(
self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
self, prompts: list[lf.Message]) -> list[LMSamplingResult]:

def _open_ai_chat_completion(prompt):
response = openai.ChatCompletion.create(
# TODO(daiyip): support conversation history.
# TODO(daiyip): support conversation history and system prompt.
messages=[{'role': 'user', 'content': prompt.text}],
**self._get_request_args(self.sampling_options),
)
response = cast(openai_object.OpenAIObject, response)
return LMSamplingResult(
[
lf.LMSample(choice.message.content, score=0.0)
Expand All @@ -195,8 +201,8 @@ def _open_ai_chat_completion(prompt):
prompts,
max_workers=8,
retry_on_errors=(
openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai_error.ServiceUnavailableError,
openai_error.RateLimitError,
),
max_attempts=self.max_attempts,
retry_interval=(1, 60),
Expand Down
2 changes: 1 addition & 1 deletion langfun/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
source: 'Message' = None,
source: Optional['Message'] = None,
# The rest are `pg.Object.__init__` arguments.
allow_partial: bool = False,
sealed: bool = False,
Expand Down