
registry refactor
yanxi0830 committed Oct 14, 2024
1 parent c50686b commit 95fd53d
Showing 8 changed files with 39 additions and 71 deletions.
14 changes: 10 additions & 4 deletions llama_stack/apis/datasets/datasets.py
@@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
 @json_schema_type
 class PostprocessedGeneration(BaseModel):
     completion_message: str
-    # structured transformed output from raw_completion_message to compute scorer metrics
-    transformed_generation: Optional[Any] = None
+    logprobs: Optional[List[TokenLogProbs]] = None
 
 
 # A sample (row) from dataset
@@ -70,8 +69,15 @@ class GenerationResponseSample(DatasetSample):
 
 @json_schema_type
 class ScorerInputSample(DatasetSample):
-    generation_output: PostprocessedGeneration
-    expected_output: Union[str, List[str]]
+    """
+    A dataset is required to have the following columns to be used for scoring:
+    - generated_answer: str
+    - expected_answer: Union[str, List[str]]
+    """
+
+    generated_answer: str
+    expected_answer: Union[str, List[str]]
+    generation_output: Optional[PostprocessedGeneration] = None
 
 
 @json_schema_type
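
For reference, a minimal sketch of building a sample under the new schema — assuming the Pydantic models above are importable as shown and that DatasetSample adds no required fields of its own; all values are illustrative:

    from llama_stack.apis.datasets.datasets import (
        PostprocessedGeneration,
        ScorerInputSample,
    )

    # generated_answer / expected_answer are now the canonical scoring columns;
    # the raw generation is optional and only attached when it should travel
    # with the sample (e.g. to keep logprobs around).
    sample = ScorerInputSample(
        generated_answer="B",  # illustrative value
        expected_answer="B",   # may also be a List[str]
        generation_output=PostprocessedGeneration(
            completion_message="The answer is B.",
        ),
    )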
12 changes: 1 addition & 11 deletions llama_stack/apis/evals/evals.py
@@ -217,18 +217,8 @@ def score(
 
 
 class BaseTask(ABC):
-    def __init__(
-        self,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
-        *args,
-        **kwargs
-    ) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.generator_processor = generator_processor
-        self.generator = generator
-        self.scorer = scorer
 
     @abstractmethod
     async def run(self, *args, **kwargs) -> EvalResult:
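
With the slimmed-down constructor, a task wires up its own components rather than receiving them through BaseTask.__init__. A hypothetical subclass (the name and the choice of scorer are illustrative, not part of this commit):

    class MyEvalTask(BaseTask):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # components are now created by the task itself, instead of being
            # injected via the old generator_processor/generator/scorer params
            self.scorer = AccuracyScorer()

        async def run(self, *args, **kwargs) -> EvalResult:
            raise NotImplementedError()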
3 changes: 1 addition & 2 deletions llama_stack/distribution/registry/datasets/__init__.py
@@ -7,5 +7,4 @@
 from ..registry import Registry
 
 
-class DatasetRegistry(Registry[BaseDataset]):
-    _REGISTRY: Dict[str, BaseDataset] = {}
+DatasetRegistry = Registry[BaseDataset]()
38 changes: 18 additions & 20 deletions llama_stack/distribution/registry/registry.py
@@ -3,36 +3,34 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AbstractSet, Dict, Generic, TypeVar
+from typing import AbstractSet, Generic, TypeVar
 
 TRegistry = TypeVar("TRegistry")
 
 
 class Registry(Generic[TRegistry]):
-    _REGISTRY: Dict[str, TRegistry] = {}
+    def __init__(self) -> None:
+        super().__init__()
+        self.registry = {}
 
-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return Registry._REGISTRY.keys()
+    def names(self) -> AbstractSet[str]:
+        return self.registry.keys()
 
-    @staticmethod
-    def register(name: str, task: TRegistry) -> None:
-        if name in Registry._REGISTRY:
+    def register(self, name: str, task: TRegistry) -> None:
+        if name in self.registry:
             raise ValueError(f"Dataset {name} already exists.")
-        Registry._REGISTRY[name] = task
+        self.registry[name] = task
 
-    @staticmethod
-    def get(name: str) -> TRegistry:
-        if name not in Registry._REGISTRY:
+    def get(self, name: str) -> TRegistry:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        return Registry._REGISTRY[name]
+        return self.registry[name]
 
-    @staticmethod
-    def delete(name: str) -> None:
-        if name not in Registry._REGISTRY:
+    def delete(self, name: str) -> None:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        del Registry._REGISTRY[name]
+        del self.registry[name]
 
-    @staticmethod
-    def reset() -> None:
-        Registry._REGISTRY = {}
+    def reset(self) -> None:
+        self.registry = {}
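
Worth noting: the old static methods always read and wrote Registry._REGISTRY, so the shadowing _REGISTRY dict each subclass declared was never actually used by them; per-instance state removes that accidental sharing. A quick usage sketch of the refactored class (the registered value is illustrative):

    registry = Registry[str]()  # each instance owns its own dict
    registry.register("greeting", "hello")
    assert "greeting" in registry.names()
    assert registry.get("greeting") == "hello"
    registry.delete("greeting")
    registry.reset()  # clears anything that remains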
5 changes: 1 addition & 4 deletions llama_stack/distribution/registry/scorers/__init__.py
@@ -9,10 +9,7 @@
 
 from ..registry import Registry
 
-
-class ScorerRegistry(Registry[BaseScorer]):
-    _REGISTRY: Dict[str, BaseScorer] = {}
-
+ScorerRegistry = Registry[BaseScorer]()
 
 SCORER_REGISTRY = {
     "accuracy": AccuracyScorer,
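
ScorerRegistry is now a module-level Registry instance rather than a class, so entries are added by calling it directly. A hedged sketch of populating it from the SCORER_REGISTRY mapping above (how the commit intends these two to interact is not shown):

    # register each scorer class under its short name; the type parameter is
    # not enforced at runtime, so storing classes rather than BaseScorer
    # instances works, even though the annotation suggests instances
    for name, scorer_cls in SCORER_REGISTRY.items():
        ScorerRegistry.register(name, scorer_cls)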
4 changes: 4 additions & 0 deletions llama_stack/providers/impls/meta_reference/evals/evals.py
@@ -71,6 +71,10 @@ async def run_scorer(
         dataset_config: EvaluateDatasetConfig,
         eval_scoring_config: EvaluateScoringConfig,
     ) -> EvaluateResponse:
+        cprint("run_scorer")
+
+        # main logic: we need to convert the dataset into List[ScorerInputSample]
+
         return EvaluateResponse(
             eval_result={},
         )
@@ -153,35 +153,9 @@ def postprocess_sample(
                 break
 
         return ScorerInputSample(
+            generated_answer=extracted_answer,
+            expected_answer=dataset_sample.data["Answer"],
             generation_output=PostprocessedGeneration(
                 completion_message=response_text,
-                transformed_generation=extracted_answer,
             ),
-            expected_output=dataset_sample.data["Answer"],
         )
-
-    # def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-    #     postprocessed_output = sample.postprocessed["postprocessed"]
-    #     expected_answer = sample.data["Answer"]
-
-    #     extracted_answer = None
-    #     for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
-    #         regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-    #         match = re.search(regex, postprocessed_output)
-    #         if match:
-    #             extracted_answer = normalize_extracted_answer(match.group(1))
-    #             break
-
-    #     score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-
-    #     return SingleEvalResult(
-    #         score_data={
-    #             "score": score,
-    #         },
-    #     )
-
-    # def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-    #     print("aggregate_results", eval_results)
-    #     sum_score = sum([result.score_data["score"] for result in eval_results])
-
-    #     return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
@@ -28,8 +28,8 @@ def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
 
 class AccuracyScorer(BaseScorer[ScorerInputSample]):
     def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
-        extracted_answer = scorer_input_sample.generation_output.transformed_generation
-        expected_answer = scorer_input_sample.expected_output
+        extracted_answer = scorer_input_sample.generated_answer
+        expected_answer = scorer_input_sample.expected_answer
 
         accuracy = (
             1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
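
End to end, the renamed fields line up directly with the scorer. A sketch with an illustrative sample, assuming the models from llama_stack/apis/datasets/datasets.py and that DatasetSample adds no required fields:

    sample = ScorerInputSample(generated_answer="B", expected_answer="B")
    result = AccuracyScorer().score_sample(sample)
    # accuracy evaluates to 1.0 here, since generated_answer equals expected_answer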
