Skip to content

Commit

Permalink
Update tests in the aggregators package to pass numpy values of the…
Browse files Browse the repository at this point in the history
… correct dtype to computations.

1. Some of these modules are passing tensorflow values to federated computations, instead these should pass numpy values because federated computations will no longer accept tensorflow values.

2. Some of these modules are passing values with the wrong dtype to computations resulting in the need to cast values to which can result in the loss of data.

PiperOrigin-RevId: 658034756
  • Loading branch information
michaelreneer authored and copybara-github committed Jul 31, 2024
1 parent e1cc0c8 commit 4be7cb5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ def test_create_value(self):
('a', computation_types.TensorType(np.int64, [3])),
('b', computation_types.TensorType(np.float32, [])),
])
value_pb, _ = value_serialization.serialize_value(
collections.OrderedDict(a=tf.constant([1, 2, 3]), b=tf.constant(42.0)),
expected_type_spec,
value = collections.OrderedDict(
a=np.array([1, 2, 3], np.int64),
b=np.array(42.0, np.float32),
)
value_pb, _ = value_serialization.serialize_value(value, expected_type_spec)
value = executor.create_value(value_pb)
self.assertIsInstance(value, executor_bindings.OwnedValueId)
# Assert the value ID was incremented.
Expand Down Expand Up @@ -226,7 +227,7 @@ def test_create_struct(self):
executor = get_executor()
expected_type_spec = computation_types.TensorType(np.int64, [3])
value_pb, _ = value_serialization.serialize_value(
tf.constant([1, 2, 3]), expected_type_spec
np.array([1, 2, 3], np.int64), expected_type_spec
)
value = executor.create_value(value_pb)
self.assertEqual(value.ref, 0)
Expand Down Expand Up @@ -264,7 +265,7 @@ def test_create_selection(self):
executor = get_executor()
expected_type_spec = computation_types.TensorType(np.int64, [3])
value_pb, _ = value_serialization.serialize_value(
tf.constant([1, 2, 3]), expected_type_spec
np.array([1, 2, 3], np.int64), expected_type_spec
)
value = executor.create_value(value_pb)
self.assertEqual(value.ref, 0)
Expand Down Expand Up @@ -298,7 +299,7 @@ def test_create_selection(self):
def test_call_with_arg(self):
executor = get_executor()
value_pb, _ = value_serialization.serialize_value(
tf.constant([1, 2, 3]),
np.array([1, 2, 3], np.int64),
computation_types.TensorType(np.int64, [3]),
)
value_ref = executor.create_value(value_pb)
Expand Down
11 changes: 8 additions & 3 deletions tensorflow_federated/python/tests/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ def map_foo_at_clients(x):
def map_foo_at_server(x):
return tff.federated_map(foo, x)

bad_tensor = tf.constant([1.0] * 10, dtype=tf.float32)
good_tensor = tf.constant([1.0], dtype=tf.float32)
bad_tensor = np.array([1.0] * 10, dtype=np.float32)
good_tensor = np.array([1.0], dtype=np.float32)
# Ensure running this computation at both placements, or unplaced, still
# raises.
with self.assertRaises(Exception):
Expand Down Expand Up @@ -405,7 +405,12 @@ def count_one_twice():

self.assertEqual((1, 1, 1), count_one_twice())

@tff.test.with_contexts(*test_contexts.get_all_contexts())
@tff.test.with_contexts(
(
'native_sync_local',
tff.backends.native.create_sync_local_cpp_execution_context,
),
)
def test_dynamic_lookup_table(self):

@tff.tensorflow.computation(
Expand Down

0 comments on commit 4be7cb5

Please sign in to comment.