[TLM] Add more bedrock models (#344)
huiwengoh authored Jan 10, 2025
1 parent e0bff1a commit 3849856
Showing 5 changed files with 26 additions and 10 deletions.
11 changes: 6 additions & 5 deletions cleanlab_studio/internal/constants.py
@@ -10,17 +10,18 @@
     "gpt-4o-mini",
     "o1-preview",
     "claude-3-haiku",
+    "claude-3.5-haiku",
     "claude-3-sonnet",
     "claude-3.5-sonnet",
+    "claude-3.5-sonnet-v2",
+    "nova-micro",
+    "nova-lite",
+    "nova-pro",
 ]
 _TLM_DEFAULT_MODEL: str = "gpt-4o-mini"
 _TLM_MAX_RETRIES: int = 3  # TODO: finalize this number
 _TLM_MAX_TOKEN_RANGE: Dict[str, Tuple[int, int]] = {  # model: (min, max)
-    "gpt-3.5-turbo-16k": (64, 4096),
-    "gpt-4": (64, 4096),
-    "gpt-4o": (64, 4096),
-    "gpt-4o-mini": (64, 4096),
-    "o1-preview": (64, 4096),
+    "default": (64, 4096),
     "claude-3-haiku": (64, 512),
     "claude-3-sonnet": (64, 512),
     "claude-3.5-sonnet": (64, 512),
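The practical effect of the new "default" entry: any model without an explicit key in _TLM_MAX_TOKEN_RANGE (the OpenAI models whose per-model entries were dropped above, and presumably the new Nova models, whose entries are not visible in this hunk) now resolves to one shared fallback range. A minimal sketch of the lookup pattern, using a toy table rather than the full dict from this commit:

    from typing import Dict, Tuple

    # Illustrative subset only; not the full _TLM_MAX_TOKEN_RANGE from the diff.
    TOKEN_RANGE: Dict[str, Tuple[int, int]] = {
        "default": (64, 4096),
        "claude-3-haiku": (64, 512),
    }

    def token_range(model: str) -> Tuple[int, int]:
        # Models with no explicit entry fall back to the shared "default" bounds.
        return TOKEN_RANGE.get(model, TOKEN_RANGE["default"])

    assert token_range("claude-3-haiku") == (64, 512)  # explicit entry wins
    assert token_range("nova-pro") == (64, 4096)       # no entry -> default (assumed)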
5 changes: 3 additions & 2 deletions cleanlab_studio/internal/tlm/validation.py
@@ -173,9 +173,10 @@ def validate_tlm_options(options: Any) -> None:
                 raise ValidationError(f"Invalid type {type(val)}, max_tokens must be an integer")

             model = options.get("model", _TLM_DEFAULT_MODEL)
-            if val < _TLM_MAX_TOKEN_RANGE[model][0] or val > _TLM_MAX_TOKEN_RANGE[model][1]:
+            max_tokens_range = _TLM_MAX_TOKEN_RANGE.get(model, _TLM_MAX_TOKEN_RANGE["default"])
+            if val < max_tokens_range[0] or val > max_tokens_range[1]:
                 raise ValidationError(
-                    f"Invalid value {val}, max_tokens for {model} must be in the range {_TLM_MAX_TOKEN_RANGE[model]}"
+                    f"Invalid value {val}, max_tokens for {model} must be in the range {max_tokens_range}"
                 )

         elif option == "model":
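A hedged usage sketch (not part of the commit) of how the rewritten check should behave, assuming the module path shown in this diff and that "nova-pro" has no explicit _TLM_MAX_TOKEN_RANGE entry, so the (64, 4096) default applies. It also shows the bug this change fixes: the old _TLM_MAX_TOKEN_RANGE[model] lookup would raise KeyError for any model missing from the dict.

    from cleanlab_studio.internal.tlm.validation import validate_tlm_options

    # Within the assumed default range of (64, 4096): no exception.
    validate_tlm_options({"model": "nova-pro", "max_tokens": 512})

    # Above the assumed default maximum: a ValidationError whose message now
    # uses the resolved max_tokens_range instead of a raw dict indexing.
    try:
        validate_tlm_options({"model": "nova-pro", "max_tokens": 10_000})
    except Exception as err:  # ValidationError in the real module
        print(err)  # Invalid value 10000, max_tokens for nova-pro must be in the range (64, 4096)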
2 changes: 1 addition & 1 deletion cleanlab_studio/version.py
@@ -1,7 +1,7 @@
 # Note to developers:
 # Consider if backend's MIN_CLI_VERSION needs updating when pushing any changes to this file.

-__version__ = "2.5.11"
+__version__ = "2.5.12"

 SCHEMA_VERSION = "0.2.0"
 MIN_SCHEMA_VERSION = "0.1.0"
4 changes: 3 additions & 1 deletion tests/tlm/conftest.py
@@ -89,7 +89,9 @@ def _get_options_dictionary(model: Optional[str]) -> dict:
     add_log_perplexity_score = np.random.choice([True, False])

     if add_max_tokens:
-        max_tokens = _TLM_MAX_TOKEN_RANGE[options.get("model", _TLM_DEFAULT_MODEL)][1]
+        max_tokens = _TLM_MAX_TOKEN_RANGE.get(
+            options.get("model", _TLM_DEFAULT_MODEL), _TLM_MAX_TOKEN_RANGE["default"]
+        )[1]
         options["max_tokens"] = int(np.random.randint(64, max_tokens))
     if add_use_self_reflection:
         options["use_self_reflection"] = random.choice([True, False])
14 changes: 13 additions & 1 deletion tests/tlm/constants.py
@@ -16,13 +16,25 @@
 CHARACTERS_PER_TOKEN: int = 4

 # Property tests for TLM
-excluded_tlm_models: List[str] = ["claude-3-sonnet", "claude-3.5-sonnet", "o1-preview"]
+excluded_tlm_models: List[str] = [
+    "claude-3-sonnet",
+    "claude-3.5-sonnet",
+    "claude-3.5-sonnet-v2",
+    "o1-preview",
+    "nova-lite",
+    "nova-pro",
+]
 VALID_TLM_MODELS: List[str] = [
     model for model in _VALID_TLM_MODELS if model not in excluded_tlm_models
 ]
 MODELS_WITH_NO_PERPLEXITY_SCORE: List[str] = [
     "claude-3-haiku",
+    "claude-3.5-haiku",
     "claude-3-sonnet",
     "claude-3.5-sonnet",
+    "claude-3.5-sonnet-v2",
     "o1-preview",
+    "nova-micro",
+    "nova-lite",
+    "nova-pro",
 ]