Skip to content

Commit

Permalink
make ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
cmpadden committed Nov 7, 2024
1 parent 2718d69 commit ab27b4c
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def _add_to_asset_metadata(
@public
@experimental
def with_usage_metadata(
context: Union[AssetExecutionContext, OpExecutionContext], output_name: Optional[str], func
context: Union[AssetExecutionContext, OpExecutionContext],
output_name: Optional[str],
func,
):
"""This wrapper can be used on any endpoint of the
`notdiamond library <https://github.com/notdiamond/notdiamond-python>`
Expand Down Expand Up @@ -179,7 +181,9 @@ def notdiamond_asset(context: AssetExecutionContext, nd: NotDiamondResource):
)
"""

api_key: str = Field(description=("NotDiamond API key. See https://app.notdiamond.ai/keys"))
api_key: str = Field(
description=("NotDiamond API key. See https://app.notdiamond.ai/keys")
)

_client: NotDiamond = PrivateAttr()

Expand All @@ -193,7 +197,9 @@ def _wrap_with_usage_metadata(
context: AssetExecutionContext,
output_name: Optional[str],
):
for attribute_names in API_ENDPOINT_CLASSES_TO_ENDPOINT_METHODS_MAPPING[api_endpoint_class]:
for attribute_names in API_ENDPOINT_CLASSES_TO_ENDPOINT_METHODS_MAPPING[
api_endpoint_class
]:
curr = self._client.__getattribute__(api_endpoint_class.value)
# Get the second to last attribute from the attribute list to reach the method.
i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_client(mock_client) -> None:
notdiamond_resource = NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
notdiamond_resource = NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
notdiamond_resource.setup_for_execution(build_init_resource_context())

mock_context = MagicMock()
Expand All @@ -47,7 +49,9 @@ def test_notdiamond_client_with_config(mock_client) -> None:
)


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_resource_with_op(mock_client, mock_context, mock_wrapper):
Expand All @@ -67,13 +71,17 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource):
result = wrap_op_in_graph_and_execute(
notdiamond_op,
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)
assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_resource_with_asset(mock_client, mock_context, mock_wrapper):
Expand All @@ -93,17 +101,23 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource):
result = materialize_to_memory(
[notdiamond_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_resource_with_graph_backed_asset(mock_client, mock_context, mock_wrapper):
def test_notdiamond_resource_with_graph_backed_asset(
mock_client, mock_context, mock_wrapper
):
@op
def model_version_op():
return ["openai/gpt-4o-mini", "openai/gpt-4o"]
Expand All @@ -129,14 +143,18 @@ def notdiamond_asset():
result = materialize_to_memory(
[notdiamond_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_resource_with_multi_asset(mock_client, mock_context, mock_wrapper):
Expand Down Expand Up @@ -181,17 +199,23 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource):
result = materialize_to_memory(
[notdiamond_multi_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_resource_with_partitioned_asset(mock_client, mock_context, mock_wrapper):
def test_notdiamond_resource_with_partitioned_asset(
mock_client, mock_context, mock_wrapper
):
notdiamond_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)])

notdiamond_partitioned_assets = []
Expand Down Expand Up @@ -234,18 +258,20 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource):
)
],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

for partition_key in notdiamond_partitions_def.get_partition_keys():
result = defs.get_job_def("notdiamond_partitioned_asset_job").execute_in_process(
partition_key=partition_key
)
result = defs.get_job_def(
"notdiamond_partitioned_asset_job"
).execute_in_process(partition_key=partition_key)
assert result.success

expected_wrapper_call_counts = (
len(notdiamond_partitioned_assets) * len(notdiamond_partitions_def.get_partition_keys())
expected_wrapper_call_counts = len(notdiamond_partitioned_assets) * len(
notdiamond_partitions_def.get_partition_keys()
)
assert mock_wrapper.call_count == expected_wrapper_call_counts

Expand All @@ -268,13 +294,17 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource):
result = wrap_op_in_graph_and_execute(
notdiamond_op,
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)
assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_wrapper_with_asset(mock_client, mock_context, mock_wrapper):
Expand All @@ -297,7 +327,8 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource):
func=client.fine_tuning.jobs.create,
)
client.fine_tuning.jobs.create(
model=["openai/gpt-4o-mini", "openai/gpt-4o"], training_file="some_training_file"
model=["openai/gpt-4o-mini", "openai/gpt-4o"],
training_file="some_training_file",
)

mock_context.add_output_metadata.assert_called_with(
Expand All @@ -313,17 +344,23 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource):
result = materialize_to_memory(
[notdiamond_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_wrapper_with_graph_backed_asset(mock_client, mock_context, mock_wrapper):
def test_notdiamond_wrapper_with_graph_backed_asset(
mock_client, mock_context, mock_wrapper
):
@op
def model_version_op():
return "openai/gpt-4o-mini"
Expand All @@ -333,7 +370,9 @@ def training_file_op():
return "some_training_file"

@op
def notdiamond_op(notdiamond_resource: NotDiamondResource, model_version, training_file):
def notdiamond_op(
notdiamond_resource: NotDiamondResource, model_version, training_file
):
assert notdiamond_resource

mock_completion = MagicMock()
Expand All @@ -350,7 +389,9 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource, model_version, traini
output_name="notdiamond_asset",
func=client.fine_tuning.jobs.create,
)
client.fine_tuning.jobs.create(model=model_version, training_file=training_file)
client.fine_tuning.jobs.create(
model=model_version, training_file=training_file
)

mock_context.add_output_metadata.assert_called_with(
metadata={
Expand All @@ -369,14 +410,18 @@ def notdiamond_asset():
result = materialize_to_memory(
[notdiamond_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_wrapper_with_multi_asset(mock_client, mock_context, mock_wrapper):
Expand All @@ -403,7 +448,8 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource):
func=client.fine_tuning.jobs.create,
)
client.fine_tuning.jobs.create(
model=["openai/gpt-4o-mini", "openai/gpt-4o"], training_file="some_training_file"
model=["openai/gpt-4o-mini", "openai/gpt-4o"],
training_file="some_training_file",
)

mock_context.add_output_metadata.assert_called_with(
Expand All @@ -420,14 +466,18 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource):
result = materialize_to_memory(
[notdiamond_multi_asset],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

assert result.success


@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata")
@patch(
"dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata"
)
@patch("dagster_contrib_notdiamond.resources.NotDiamond")
def test_notdiamond_wrapper_with_partitioned_asset(mock_client, mock_wrapper):
notdiamond_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)])
Expand All @@ -453,7 +503,9 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource):
mock_usage.total_tokens = 1
mock_usage.completion_tokens = 1
mock_completion.usage = mock_usage
mock_client.return_value.fine_tuning.jobs.create.return_value = mock_completion
mock_client.return_value.fine_tuning.jobs.create.return_value = (
mock_completion
)

with notdiamond_resource.get_client(context=mock_context) as client:
client.fine_tuning.jobs.create = with_usage_metadata(
Expand Down Expand Up @@ -487,12 +539,14 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource):
)
],
resources={
"notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234")
"notdiamond_resource": NotDiamondResource(
api_key="xoxp-1234123412341234-12341234-1234"
)
},
)

for partition_key in notdiamond_partitions_def.get_partition_keys():
result = defs.get_job_def("notdiamond_partitioned_asset_job").execute_in_process(
partition_key=partition_key
)
result = defs.get_job_def(
"notdiamond_partitioned_asset_job"
).execute_in_process(partition_key=partition_key)
assert result.success
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,39 @@
)
from dagster_contrib_notdiamond import NotDiamondResource


@op
def notdiamond_op(context: OpExecutionContext, notdiamond: NotDiamondResource) -> Tuple[str, str]:
def notdiamond_op(
context: OpExecutionContext, notdiamond: NotDiamondResource
) -> Tuple[str, str]:
with notdiamond.get_client(context) as client:
session_id, best_llm = client.model_select(
model=["openai/gpt-4o", "openai/gpt-4o-mini"],
messages=[{"role": "user", "content": "Say this is a test"}]
messages=[{"role": "user", "content": "Say this is a test"}],
)
return session_id, str(best_llm)

notdiamond_op_job = GraphDefinition(name="notdiamond_op_job", node_defs=[notdiamond_op]).to_job()

notdiamond_op_job = GraphDefinition(
name="notdiamond_op_job", node_defs=[notdiamond_op]
).to_job()


@asset(compute_kind="NotDiamond")
def notdiamond_asset(context: AssetExecutionContext, notdiamond: NotDiamondResource) -> Tuple[str, str]:
def notdiamond_asset(
context: AssetExecutionContext, notdiamond: NotDiamondResource
) -> Tuple[str, str]:
with notdiamond.get_client(context) as client:
session_id, best_llm = client.model_select(
model=["openai/gpt-4o", "openai/gpt-4o-mini"],
messages=[{"role": "user", "content": "Say this is a test"}]
messages=[{"role": "user", "content": "Say this is a test"}],
)
return session_id, str(best_llm)

notdiamond_asset_job = define_asset_job(name="notdiamond_asset_job", selection="notdiamond_asset")

notdiamond_asset_job = define_asset_job(
name="notdiamond_asset_job", selection="notdiamond_asset"
)

defs = Definitions(
assets=[notdiamond_asset],
Expand All @@ -44,6 +56,8 @@ def notdiamond_asset(context: AssetExecutionContext, notdiamond: NotDiamondResou

if __name__ == "__main__":
result = notdiamond_op_job.execute_in_process(
resources={"notdiamond": NotDiamondResource(api_key=EnvVar("NOTDIAMOND_API_KEY"))}
resources={
"notdiamond": NotDiamondResource(api_key=EnvVar("NOTDIAMOND_API_KEY"))
}
)
print(result.output_for_node("notdiamond_op"))
print(result.output_for_node("notdiamond_op"))

0 comments on commit ab27b4c

Please sign in to comment.