From 95fd53d2921cec04e8c9e71fb49183ac29cf8071 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 14 Oct 2024 16:09:55 -0700
Subject: [PATCH] registry refactor

---
 llama_stack/apis/datasets/datasets.py          | 14 +++++--
 llama_stack/apis/evals/evals.py                | 12 +-----
 .../registry/datasets/__init__.py              |  3 +-
 llama_stack/distribution/registry/registry.py  | 38 +++++++++----------
 .../distribution/registry/scorers/__init__.py  |  5 +--
 .../impls/meta_reference/evals/evals.py        |  4 ++
 .../evals/processor/mmlu_processor.py          | 30 +--------------
 .../evals/scorer/basic_scorers.py              |  4 +-
 8 files changed, 39 insertions(+), 71 deletions(-)

diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 11a3f60964..0f4354c3fc 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
 @json_schema_type
 class PostprocessedGeneration(BaseModel):
     completion_message: str
-    # structured transformed output from raw_completion_message to compute scorer metrics
-    transformed_generation: Optional[Any] = None
+    logprobs: Optional[List[TokenLogProbs]] = None


 # A sample (row) from dataset
@@ -70,8 +69,15 @@ class GenerationResponseSample(DatasetSample):

 @json_schema_type
 class ScorerInputSample(DatasetSample):
-    generation_output: PostprocessedGeneration
-    expected_output: Union[str, List[str]]
+    """
+    A dataset is required to have the following columns to be used for scoring:
+    - generated_answer: str
+    - expected_answer: Union[str, List[str]]
+    """
+
+    generated_answer: str
+    expected_answer: Union[str, List[str]]
+    generation_output: Optional[PostprocessedGeneration] = None


 @json_schema_type
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
index af0b291e87..fb3aa6cd4d 100644
--- a/llama_stack/apis/evals/evals.py
+++ b/llama_stack/apis/evals/evals.py
@@ -217,18 +217,8 @@ def score(


 class BaseTask(ABC):
-    def __init__(
-        self,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
-        *args,
-        **kwargs
-    ) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.generator_processor = generator_processor
-        self.generator = generator
-        self.scorer = scorer

     @abstractmethod
     async def run(self, *args, **kwargs) -> EvalResult:
diff --git a/llama_stack/distribution/registry/datasets/__init__.py b/llama_stack/distribution/registry/datasets/__init__.py
index 8164758120..4474c8d7d8 100644
--- a/llama_stack/distribution/registry/datasets/__init__.py
+++ b/llama_stack/distribution/registry/datasets/__init__.py
@@ -7,5 +7,4 @@
 from ..registry import Registry


-class DatasetRegistry(Registry[BaseDataset]):
-    _REGISTRY: Dict[str, BaseDataset] = {}
+DatasetRegistry = Registry[BaseDataset]()
diff --git a/llama_stack/distribution/registry/registry.py b/llama_stack/distribution/registry/registry.py
index 313fb6d4e4..702ed7d869 100644
--- a/llama_stack/distribution/registry/registry.py
+++ b/llama_stack/distribution/registry/registry.py
@@ -3,36 +3,34 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AbstractSet, Dict, Generic, TypeVar
+from typing import AbstractSet, Generic, TypeVar

 TRegistry = TypeVar("TRegistry")


 class Registry(Generic[TRegistry]):
-    _REGISTRY: Dict[str, TRegistry] = {}

-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return Registry._REGISTRY.keys()
+    def __init__(self) -> None:
+        super().__init__()
+        self.registry = {}

-    @staticmethod
-    def register(name: str, task: TRegistry) -> None:
-        if name in Registry._REGISTRY:
+    def names(self) -> AbstractSet[str]:
+        return self.registry.keys()
+
+    def register(self, name: str, task: TRegistry) -> None:
+        if name in self.registry:
             raise ValueError(f"Dataset {name} already exists.")
-        Registry._REGISTRY[name] = task
+        self.registry[name] = task

-    @staticmethod
-    def get(name: str) -> TRegistry:
-        if name not in Registry._REGISTRY:
+    def get(self, name: str) -> TRegistry:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        return Registry._REGISTRY[name]
+        return self.registry[name]

-    @staticmethod
-    def delete(name: str) -> None:
-        if name not in Registry._REGISTRY:
+    def delete(self, name: str) -> None:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        del Registry._REGISTRY[name]
+        del self.registry[name]

-    @staticmethod
-    def reset() -> None:
-        Registry._REGISTRY = {}
+    def reset(self) -> None:
+        self.registry = {}
diff --git a/llama_stack/distribution/registry/scorers/__init__.py b/llama_stack/distribution/registry/scorers/__init__.py
index 084a620a74..dedf32ac3a 100644
--- a/llama_stack/distribution/registry/scorers/__init__.py
+++ b/llama_stack/distribution/registry/scorers/__init__.py
@@ -9,10 +9,7 @@
 from ..registry import Registry


-
-class ScorerRegistry(Registry[BaseScorer]):
-    _REGISTRY: Dict[str, BaseScorer] = {}
-
+ScorerRegistry = Registry[BaseScorer]()

 SCORER_REGISTRY = {
     "accuracy": AccuracyScorer,
diff --git a/llama_stack/providers/impls/meta_reference/evals/evals.py b/llama_stack/providers/impls/meta_reference/evals/evals.py
index 1d703a27ce..abd1938ada 100644
--- a/llama_stack/providers/impls/meta_reference/evals/evals.py
+++ b/llama_stack/providers/impls/meta_reference/evals/evals.py
@@ -71,6 +71,10 @@ async def run_scorer(
         dataset_config: EvaluateDatasetConfig,
         eval_scoring_config: EvaluateScoringConfig,
     ) -> EvaluateResponse:
+        cprint("run_scorer")
+
+        # main logic: convert the dataset into List[ScorerInputSample] and run the scorers
+
         return EvaluateResponse(
             eval_result={},
         )
diff --git a/llama_stack/providers/impls/meta_reference/evals/processor/mmlu_processor.py b/llama_stack/providers/impls/meta_reference/evals/processor/mmlu_processor.py
index 83460bb0c5..fc2d9eb642 100644
--- a/llama_stack/providers/impls/meta_reference/evals/processor/mmlu_processor.py
+++ b/llama_stack/providers/impls/meta_reference/evals/processor/mmlu_processor.py
@@ -153,35 +153,9 @@ def postprocess_sample(
                 break

         return ScorerInputSample(
+            generated_answer=extracted_answer,
+            expected_answer=dataset_sample.data["Answer"],
             generation_output=PostprocessedGeneration(
                 completion_message=response_text,
-                transformed_generation=extracted_answer,
             ),
-            expected_output=dataset_sample.data["Answer"],
         )
-
-    # def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-    #     postprocessed_output = sample.postprocessed["postprocessed"]
-    #     expected_answer = sample.data["Answer"]
-
-    #     extracted_answer = None
-    #     for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
-    #         regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-    #         match = re.search(regex, postprocessed_output)
-    #         if match:
-    #             extracted_answer = normalize_extracted_answer(match.group(1))
-    #             break
-
-    #     score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-
-    #     return SingleEvalResult(
-    #         score_data={
-    #             "score": score,
-    #         },
-    #     )
-
-    # def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-    #     print("aggregate_results", eval_results)
-    #     sum_score = sum([result.score_data["score"] for result in eval_results])
-
-    #     return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
diff --git a/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py b/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py
index 48d8caa3fa..6099353a87 100644
--- a/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py
+++ b/llama_stack/providers/impls/meta_reference/evals/scorer/basic_scorers.py
@@ -28,8 +28,8 @@ def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:


 class AccuracyScorer(BaseScorer[ScorerInputSample]):
     def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
-        extracted_answer = scorer_input_sample.generation_output.transformed_generation
-        expected_answer = scorer_input_sample.expected_output
+        extracted_answer = scorer_input_sample.generated_answer
+        expected_answer = scorer_input_sample.expected_answer
         accuracy = (
             1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
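
Not part of the patch: a minimal usage sketch of the instance-based registries introduced above, assuming the module path from the diff (llama_stack.distribution.registry.registry). DummyDataset is a hypothetical stand-in for BaseDataset, used only to keep the example self-contained.

# Sketch only: each registry is now a plain Registry instance, so datasets and
# scorers keep separate state instead of sharing the old class-level _REGISTRY.
from dataclasses import dataclass

from llama_stack.distribution.registry.registry import Registry


@dataclass
class DummyDataset:  # hypothetical stand-in for BaseDataset
    name: str


dataset_registry = Registry[DummyDataset]()

dataset_registry.register("mmlu-eval", DummyDataset(name="mmlu-eval"))
assert "mmlu-eval" in dataset_registry.names()

ds = dataset_registry.get("mmlu-eval")  # returns the registered instance
dataset_registry.delete("mmlu-eval")    # raises ValueError for unknown names
dataset_registry.reset()                # clears only this registry

Because the state now lives on each instance, DatasetRegistry and ScorerRegistry above are two independent Registry objects; the removed @staticmethod versions always operated on Registry._REGISTRY, so every registry shared one dict.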
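
On the scoring side, a small self-contained sketch of the per-sample rule that AccuracyScorer applies to the new flat columns; plain Python, no llama_stack imports. The list-of-acceptable-answers branch is an assumption based on the ScorerInputSample docstring (expected_answer: Union[str, List[str]]); the hunk above only shows the exact string match.

from typing import List, Union


def accuracy_score(generated_answer: str, expected_answer: Union[str, List[str]]) -> float:
    # Mirrors the basic_scorers.py hunk: exact match scores 1.0, anything else
    # (including an empty generation) scores 0.0.
    if isinstance(expected_answer, list):
        # Assumed extension for the List[str] case described in the docstring.
        return 1.0 if generated_answer and generated_answer in expected_answer else 0.0
    return 1.0 if generated_answer and generated_answer == expected_answer else 0.0


print(accuracy_score("B", "B"))          # 1.0
print(accuracy_score("", "B"))           # 0.0
print(accuracy_score("C", ["B", "C"]))   # 1.0 (assumed list handling)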