vLLM: Update vLLM-cpu to v0.6.6-post1 (#12728)
Update vLLM-cpu to v0.6.6-post1
xiangyuT authored Jan 22, 2025
1 parent 78cca0a commit c9b6c94
Showing 11 changed files with 2,085 additions and 409 deletions.
9 changes: 6 additions & 3 deletions docker/llm/serving/cpu/docker/Dockerfile
@@ -16,6 +16,8 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
apt-get update && \
apt-get install -y --no-install-recommends wrk patch g++ && \
pip install --pre --upgrade ipex-llm[serving] && \
apt-get install -y gcc-12 g++-12 libnuma-dev && \
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
# Fix Trivy CVE Issues
pip install Jinja2==3.1.3 transformers==4.36.2 gradio==4.19.2 cryptography==42.0.4 && \
# Fix Qwen model adapter in fastchat
@@ -24,10 +26,11 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
# Install vllm
git clone https://github.com/vllm-project/vllm.git && \
cd ./vllm && \
git checkout v0.4.2 && \
pip install wheel packaging ninja setuptools>=49.4.0 numpy && \
git checkout v0.6.6.post1 && \
pip install "cmake>=3.26" wheel packaging ninja "setuptools-scm>=8" numpy && \
pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu && \
VLLM_TARGET_DEVICE=cpu python3 setup.py install
VLLM_TARGET_DEVICE=cpu python3 setup.py install && \
pip install ray


COPY ./vllm_offline_inference.py /llm/
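With the bump to v0.6.6.post1, the image now builds vLLM from source for the CPU target, switching the toolchain to gcc-12 and adding ray as a runtime dependency. A quick post-build check, assuming the container's default python3 can import the freshly built package (a minimal sketch, not part of the commit):

import vllm
print(vllm.__version__)  # expected to report 0.6.6.post1 in this image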
1 change: 0 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -693,7 +693,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features,
mp_group,
None,
None,
optimize_lm_head,
None
)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -749,7 +749,7 @@ def forward(self, x: torch.Tensor):
dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
return result
return result.to(x.dtype)


class FP16Linear(nn.Linear):
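The low_bit_linear change casts the result back to the activation dtype after the optional all-reduce and bias add, so the layer always returns the dtype it was fed. A toy PyTorch sketch of the pattern (illustrative only, not the ipex-llm kernel):

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)      # caller's activation dtype
weight = torch.randn(8, 4, dtype=torch.float32)  # compute path may run in fp32
result = x.float() @ weight                      # higher-precision accumulation
result = result.to(x.dtype)                      # mirrors the patched return
assert result.dtype == torch.bfloat16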
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/vllm/cpu/engine/__init__.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"run_mp_engine",
]
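With run_mp_engine added to __all__, the multiprocessing entry point can be imported alongside the engine classes (a usage sketch):

from ipex_llm.vllm.cpu.engine import (
    IPEXLLMAsyncLLMEngine,
    IPEXLLMClass,
    IPEXLLMLLMEngine,
    run_mp_engine,
)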
234 changes: 159 additions & 75 deletions python/llm/src/ipex_llm/vllm/cpu/engine/engine.py
@@ -13,18 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional, Union
from vllm.logger import init_logger
from typing import Dict, Optional, Any, Union, Type
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
from vllm.config import VllmConfig
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
from vllm.usage.usage_lib import UsageContext
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.multiprocessing.engine import MQLLMEngine
import signal
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.config import CompilationConfig
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
from vllm import envs
from vllm.v1.engine.async_llm import AsyncLLM
import os

from ipex_llm.utils.common import invalidInputError
logger = init_logger(__name__)


class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -35,79 +45,100 @@ def __init__(self, *args, **kwargs):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: Optional[str] = None,
load_in_low_bit: str = "sym_int4",
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Enable ipex-llm optimizations
engine_config = engine_args.create_engine_config()
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
invalidInputError(not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend."))
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
invalidInputError(engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
)
return engine
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)

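# --- usage sketch (illustrative, not part of this commit) --------------------
# The async engine is now built the same way as in upstream vLLM, with
# `load_in_low_bit` selecting the ipex-llm weight format applied before the
# engine starts. The model path below is a placeholder.
from vllm.engine.arg_utils import AsyncEngineArgs

example_async_args = AsyncEngineArgs(model="/llm/models/Llama-2-7b-chat-hf")
async_engine = IPEXLLMAsyncLLMEngine.from_engine_args(
    example_async_args, load_in_low_bit="sym_int4")
# ------------------------------------------------------------------------------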

class IPEXLLMClass(LLM):
class IPEXLLMAsyncV1Engine(AsyncLLM):

def __init__(self, *args, **kwargs):
print("IPEX-LLM V1 engine get started...")
super().__init__(*args, **kwargs)

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
) -> "AsyncLLM":
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)


class IPEXLLMClass(LLM):
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_in_low_bit: Optional[str] = None,
disable_async_output_proc: bool = True,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]]=None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
load_in_low_bit: str = "sym_int4",
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''

if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = None

engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
@@ -116,16 +147,60 @@ def __init__(
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
**kwargs,
)
self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args,
load_in_low_bit=load_in_low_bit)
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
# TODO(gc): we will need to override this function
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS,
load_in_low_bit=load_in_low_bit)

self.request_counter = Counter()

@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
# from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
return IPEXLLMLLMV1Engine # type: ignore
return IPEXLLMLLMEngine

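# --- usage sketch (illustrative, not part of this commit) --------------------
# IPEXLLMClass keeps the stock vLLM `LLM` API; the extra `load_in_low_bit`
# argument picks the quantization applied during model conversion. The model
# path is a placeholder.
llm = IPEXLLMClass(model="/llm/models/Llama-2-7b-chat-hf",
                   load_in_low_bit="sym_int4")
outputs = llm.generate("What is AI?")
# ------------------------------------------------------------------------------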

# TODO(gc): implement this later...
class IPEXLLMLLMV1Engine(V1LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
enable_multiprocessing: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.

# TODO(gc): delete this later
print("IPEXLLM V1 Engine")
# This does not work as it is in the separate process...
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context,
stat_loggers, enable_multiprocessing)


class IPEXLLMLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
@@ -136,35 +211,44 @@ def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: Optional[str] = None,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
# TODO(gc): Delete
print("Use vLLM v0 engine")
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)

# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
invalidInputError(engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor

# Create the LLM engine.
engine = cls(**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
)
return engine

class IPEXLLMMQLLMEngine(MQLLMEngine):
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, ipc_path)


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str, engine_alive):

def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa

try:
signal.signal(signal.SIGTERM, signal_handler)

engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e # noqa

if os.getenv("VLLM_USE_V1"):
IPEXLLMAsyncLLMEngine = IPEXLLMAsyncV1Engine
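The new run_mp_engine entry point mirrors vLLM's multiprocessing engine launcher: it installs a SIGTERM handler, applies the low-bit conversion inside the child process, and flips engine_alive if the engine dies. A minimal launch sketch (assumed usage; the model path and IPC address are placeholders):

from multiprocessing import Process, Value

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext

from ipex_llm.vllm.cpu.engine import run_mp_engine

engine_alive = Value('b', True, lock=False)
proc = Process(target=run_mp_engine,
               args=(AsyncEngineArgs(model="/llm/models/Llama-2-7b-chat-hf"),
                     UsageContext.OPENAI_API_SERVER,
                     "ipc:///tmp/ipex_llm_engine",  # placeholder ZMQ address
                     "sym_int4",
                     engine_alive))
proc.start()

Note the module-level alias directly above: when VLLM_USE_V1 is set in the environment, IPEXLLMAsyncLLMEngine resolves to the V1 implementation.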