diff --git a/llama_stack/providers/tests/ci_test_config.yaml b/llama_stack/providers/tests/ci_test_config.yaml
index ac135872df..4e47a0ea57 100644
--- a/llama_stack/providers/tests/ci_test_config.yaml
+++ b/llama_stack/providers/tests/ci_test_config.yaml
@@ -16,8 +16,6 @@ inference_providers:
 - ollama
 - fireworks
 - together
-- tgi
-- vllm
 
 test_models:
   text: meta-llama/Llama-3.1-8B-Instruct
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 09e30264b7..c51f98b0dc 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -20,6 +20,8 @@
 from .env import get_env_or_fail
 
+from .report import Report
+
 
 class ProviderFixture(BaseModel):
     providers: List[Provider]
@@ -61,6 +63,8 @@ def pytest_configure(config):
         key, value = env_var.split("=", 1)
         os.environ[key] = value
 
+    config.pluginmanager.register(Report(config))
+
 
 def pytest_addoption(parser):
     parser.addoption(
diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py
new file mode 100644
index 0000000000..6cc1734dba
--- /dev/null
+++ b/llama_stack/providers/tests/report.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from collections import defaultdict
+
+import pytest
+from llama_models.datatypes import CoreModelId
+
+from pytest_html.basereport import _process_outcome
+
+
+INFERNECE_APIS = ["chat_completion"]
+FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
+SUPPORTED_MODELS = {
+    "ollama": [
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_1b.value,
+    ],
+    "fireworks": [
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ],
+    "together": [
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ],
+}
+
+
+class Report:
+
+    def __init__(self, _config):
+        self.report_data = defaultdict(dict)
+        self.test_data = dict()
+
+    @pytest.hookimpl(tryfirst=True)
+    def pytest_runtest_logreport(self, report):
+        # This hook is called in several phases, including setup, call and teardown
+        # The test is considered failed / error if any of the outcomes is not "Passed"
+        outcome = _process_outcome(report)
+        if report.nodeid not in self.test_data:
+            self.test_data[report.nodeid] = outcome
+        elif self.test_data[report.nodeid] != outcome and outcome != "Passed":
+            self.test_data[report.nodeid] = outcome
+
+    def pytest_html_results_summary(self, prefix, summary, postfix):
+        prefix.append("

Inference Providers:

") + for provider in self.report_data.keys(): + prefix.extend( + [ + f"

{ provider }

", + "") + + @pytest.hookimpl(tryfirst=True) + def pytest_runtest_makereport(self, item, call): + if call.when != "setup": + return + # generate the mapping from provider, api/functionality to test nodeid + provider = item.callspec.params.get("inference_stack") + if provider is not None: + api, functionality = self._process_function_name(item.name.split("[")[0]) + + api_test_funcs = self.report_data[provider].get(api, set()) + functionality_test_funcs = self.report_data[provider].get( + functionality, set() + ) + api_test_funcs.add(item.nodeid) + functionality_test_funcs.add(item.nodeid) + self.report_data[provider][api] = api_test_funcs + self.report_data[provider][functionality] = functionality_test_funcs + + def _process_function_name(self, function_name): + api, functionality = None, None + for val in INFERNECE_APIS: + if val in function_name: + api = val + for val in FUNCTIONALITIES: + if val in function_name: + functionality = val + return api, functionality + + def _print_result_icon(self, result): + if result == "Passed": + return "✅" + else: + # result == "Failed" or result == "Error": + return "❌" diff --git a/requirements.txt b/requirements.txt index 304467ddcc..77cb31fdb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ requests rich setuptools termcolor +pytest-html