From ab27b4c94549fe08a6b1908effa38aedf0a79e81 Mon Sep 17 00:00:00 2001 From: Colton Padden Date: Thu, 7 Nov 2024 09:56:38 -0500 Subject: [PATCH] make ruff --- .../dagster_contrib_notdiamond/resources.py | 12 +- .../test_resources.py | 126 +++++++++++++----- .../example_job/example_notdiamond.py | 30 +++-- 3 files changed, 121 insertions(+), 47 deletions(-) diff --git a/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond/resources.py b/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond/resources.py index f433fdd..69e1593 100644 --- a/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond/resources.py +++ b/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond/resources.py @@ -47,7 +47,9 @@ def _add_to_asset_metadata( @public @experimental def with_usage_metadata( - context: Union[AssetExecutionContext, OpExecutionContext], output_name: Optional[str], func + context: Union[AssetExecutionContext, OpExecutionContext], + output_name: Optional[str], + func, ): """This wrapper can be used on any endpoint of the `notdiamond library ` @@ -179,7 +181,9 @@ def notdiamond_asset(context: AssetExecutionContext, nd: NotDiamondResource): ) """ - api_key: str = Field(description=("NotDiamond API key. See https://app.notdiamond.ai/keys")) + api_key: str = Field( + description=("NotDiamond API key. See https://app.notdiamond.ai/keys") + ) _client: NotDiamond = PrivateAttr() @@ -193,7 +197,9 @@ def _wrap_with_usage_metadata( context: AssetExecutionContext, output_name: Optional[str], ): - for attribute_names in API_ENDPOINT_CLASSES_TO_ENDPOINT_METHODS_MAPPING[api_endpoint_class]: + for attribute_names in API_ENDPOINT_CLASSES_TO_ENDPOINT_METHODS_MAPPING[ + api_endpoint_class + ]: curr = self._client.__getattribute__(api_endpoint_class.value) # Get the second to last attribute from the attribute list to reach the method. i = 0 diff --git a/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond_tests/test_resources.py b/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond_tests/test_resources.py index 25dd7ec..0f0d782 100644 --- a/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond_tests/test_resources.py +++ b/libraries/dagster-contrib-notdiamond/dagster_contrib_notdiamond_tests/test_resources.py @@ -23,7 +23,9 @@ @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_client(mock_client) -> None: - notdiamond_resource = NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + notdiamond_resource = NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) notdiamond_resource.setup_for_execution(build_init_resource_context()) mock_context = MagicMock() @@ -47,7 +49,9 @@ def test_notdiamond_client_with_config(mock_client) -> None: ) -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.OpExecutionContext", autospec=OpExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_resource_with_op(mock_client, mock_context, mock_wrapper): @@ -67,13 +71,17 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource): result = wrap_op_in_graph_and_execute( notdiamond_op, resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_resource_with_asset(mock_client, mock_context, mock_wrapper): @@ -93,17 +101,23 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource): result = materialize_to_memory( [notdiamond_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") -def test_notdiamond_resource_with_graph_backed_asset(mock_client, mock_context, mock_wrapper): +def test_notdiamond_resource_with_graph_backed_asset( + mock_client, mock_context, mock_wrapper +): @op def model_version_op(): return ["openai/gpt-4o-mini", "openai/gpt-4o"] @@ -129,14 +143,18 @@ def notdiamond_asset(): result = materialize_to_memory( [notdiamond_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_resource_with_multi_asset(mock_client, mock_context, mock_wrapper): @@ -181,17 +199,23 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource): result = materialize_to_memory( [notdiamond_multi_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") -def test_notdiamond_resource_with_partitioned_asset(mock_client, mock_context, mock_wrapper): +def test_notdiamond_resource_with_partitioned_asset( + mock_client, mock_context, mock_wrapper +): notdiamond_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)]) notdiamond_partitioned_assets = [] @@ -234,18 +258,20 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource): ) ], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) for partition_key in notdiamond_partitions_def.get_partition_keys(): - result = defs.get_job_def("notdiamond_partitioned_asset_job").execute_in_process( - partition_key=partition_key - ) + result = defs.get_job_def( + "notdiamond_partitioned_asset_job" + ).execute_in_process(partition_key=partition_key) assert result.success - expected_wrapper_call_counts = ( - len(notdiamond_partitioned_assets) * len(notdiamond_partitions_def.get_partition_keys()) + expected_wrapper_call_counts = len(notdiamond_partitioned_assets) * len( + notdiamond_partitions_def.get_partition_keys() ) assert mock_wrapper.call_count == expected_wrapper_call_counts @@ -268,13 +294,17 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource): result = wrap_op_in_graph_and_execute( notdiamond_op, resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_wrapper_with_asset(mock_client, mock_context, mock_wrapper): @@ -297,7 +327,8 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource): func=client.fine_tuning.jobs.create, ) client.fine_tuning.jobs.create( - model=["openai/gpt-4o-mini", "openai/gpt-4o"], training_file="some_training_file" + model=["openai/gpt-4o-mini", "openai/gpt-4o"], + training_file="some_training_file", ) mock_context.add_output_metadata.assert_called_with( @@ -313,17 +344,23 @@ def notdiamond_asset(notdiamond_resource: NotDiamondResource): result = materialize_to_memory( [notdiamond_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") -def test_notdiamond_wrapper_with_graph_backed_asset(mock_client, mock_context, mock_wrapper): +def test_notdiamond_wrapper_with_graph_backed_asset( + mock_client, mock_context, mock_wrapper +): @op def model_version_op(): return "openai/gpt-4o-mini" @@ -333,7 +370,9 @@ def training_file_op(): return "some_training_file" @op - def notdiamond_op(notdiamond_resource: NotDiamondResource, model_version, training_file): + def notdiamond_op( + notdiamond_resource: NotDiamondResource, model_version, training_file + ): assert notdiamond_resource mock_completion = MagicMock() @@ -350,7 +389,9 @@ def notdiamond_op(notdiamond_resource: NotDiamondResource, model_version, traini output_name="notdiamond_asset", func=client.fine_tuning.jobs.create, ) - client.fine_tuning.jobs.create(model=model_version, training_file=training_file) + client.fine_tuning.jobs.create( + model=model_version, training_file=training_file + ) mock_context.add_output_metadata.assert_called_with( metadata={ @@ -369,14 +410,18 @@ def notdiamond_asset(): result = materialize_to_memory( [notdiamond_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_wrapper_with_multi_asset(mock_client, mock_context, mock_wrapper): @@ -403,7 +448,8 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource): func=client.fine_tuning.jobs.create, ) client.fine_tuning.jobs.create( - model=["openai/gpt-4o-mini", "openai/gpt-4o"], training_file="some_training_file" + model=["openai/gpt-4o-mini", "openai/gpt-4o"], + training_file="some_training_file", ) mock_context.add_output_metadata.assert_called_with( @@ -420,14 +466,18 @@ def notdiamond_multi_asset(notdiamond_resource: NotDiamondResource): result = materialize_to_memory( [notdiamond_multi_asset], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) assert result.success -@patch("dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata") +@patch( + "dagster_contrib_notdiamond.resources.NotDiamondResource._wrap_with_usage_metadata" +) @patch("dagster_contrib_notdiamond.resources.NotDiamond") def test_notdiamond_wrapper_with_partitioned_asset(mock_client, mock_wrapper): notdiamond_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)]) @@ -453,7 +503,9 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource): mock_usage.total_tokens = 1 mock_usage.completion_tokens = 1 mock_completion.usage = mock_usage - mock_client.return_value.fine_tuning.jobs.create.return_value = mock_completion + mock_client.return_value.fine_tuning.jobs.create.return_value = ( + mock_completion + ) with notdiamond_resource.get_client(context=mock_context) as client: client.fine_tuning.jobs.create = with_usage_metadata( @@ -487,12 +539,14 @@ def notdiamond_partitioned_asset(notdiamond_resource: NotDiamondResource): ) ], resources={ - "notdiamond_resource": NotDiamondResource(api_key="xoxp-1234123412341234-12341234-1234") + "notdiamond_resource": NotDiamondResource( + api_key="xoxp-1234123412341234-12341234-1234" + ) }, ) for partition_key in notdiamond_partitions_def.get_partition_keys(): - result = defs.get_job_def("notdiamond_partitioned_asset_job").execute_in_process( - partition_key=partition_key - ) + result = defs.get_job_def( + "notdiamond_partitioned_asset_job" + ).execute_in_process(partition_key=partition_key) assert result.success diff --git a/libraries/dagster-contrib-notdiamond/example_job/example_notdiamond.py b/libraries/dagster-contrib-notdiamond/example_job/example_notdiamond.py index 56f5a37..0030a2d 100644 --- a/libraries/dagster-contrib-notdiamond/example_job/example_notdiamond.py +++ b/libraries/dagster-contrib-notdiamond/example_job/example_notdiamond.py @@ -12,27 +12,39 @@ ) from dagster_contrib_notdiamond import NotDiamondResource + @op -def notdiamond_op(context: OpExecutionContext, notdiamond: NotDiamondResource) -> Tuple[str, str]: +def notdiamond_op( + context: OpExecutionContext, notdiamond: NotDiamondResource +) -> Tuple[str, str]: with notdiamond.get_client(context) as client: session_id, best_llm = client.model_select( model=["openai/gpt-4o", "openai/gpt-4o-mini"], - messages=[{"role": "user", "content": "Say this is a test"}] + messages=[{"role": "user", "content": "Say this is a test"}], ) return session_id, str(best_llm) -notdiamond_op_job = GraphDefinition(name="notdiamond_op_job", node_defs=[notdiamond_op]).to_job() + +notdiamond_op_job = GraphDefinition( + name="notdiamond_op_job", node_defs=[notdiamond_op] +).to_job() + @asset(compute_kind="NotDiamond") -def notdiamond_asset(context: AssetExecutionContext, notdiamond: NotDiamondResource) -> Tuple[str, str]: +def notdiamond_asset( + context: AssetExecutionContext, notdiamond: NotDiamondResource +) -> Tuple[str, str]: with notdiamond.get_client(context) as client: session_id, best_llm = client.model_select( model=["openai/gpt-4o", "openai/gpt-4o-mini"], - messages=[{"role": "user", "content": "Say this is a test"}] + messages=[{"role": "user", "content": "Say this is a test"}], ) return session_id, str(best_llm) -notdiamond_asset_job = define_asset_job(name="notdiamond_asset_job", selection="notdiamond_asset") + +notdiamond_asset_job = define_asset_job( + name="notdiamond_asset_job", selection="notdiamond_asset" +) defs = Definitions( assets=[notdiamond_asset], @@ -44,6 +56,8 @@ def notdiamond_asset(context: AssetExecutionContext, notdiamond: NotDiamondResou if __name__ == "__main__": result = notdiamond_op_job.execute_in_process( - resources={"notdiamond": NotDiamondResource(api_key=EnvVar("NOTDIAMOND_API_KEY"))} + resources={ + "notdiamond": NotDiamondResource(api_key=EnvVar("NOTDIAMOND_API_KEY")) + } ) - print(result.output_for_node("notdiamond_op")) \ No newline at end of file + print(result.output_for_node("notdiamond_op"))